bit of cleaning up
[opus.git] / libcelt / vq.c
1 /* (C) 2007 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #include <math.h>
33 #include <stdlib.h>
34 #include "cwrs.h"
35 #include "vq.h"
36
37
38 /* Improved algebraic pulse-base quantiser. The signal x is replaced by the sum of the pitch 
39    a combination of pulses such that its norm is still equal to 1. The only difference with 
40    the quantiser above is that the search is more complete. */
41 void alg_quant(float *x, float *W, int N, int K, float *p, float alpha, ec_enc *enc)
42 {
43    int L = 3;
44    //float tata[200];
45    float _y[L][N];
46    int _iy[L][N];
47    //float tata2[200];
48    float _ny[L][N];
49    int _iny[L][N];
50    float *(ny[L]), *(y[L]);
51    int *(iny[L]), *(iy[L]);
52    int i, j, m;
53    int pulsesLeft;
54    float xy[L], nxy[L];
55    float yy[L], nyy[L];
56    float yp[L], nyp[L];
57    float best_scores[L];
58    float Rpp=0, Rxp=0;
59    float gain[L];
60    int maxL = 1;
61    
62    for (m=0;m<L;m++)
63    {
64       ny[m] = _ny[m];
65       iny[m] = _iny[m];
66       y[m] = _y[m];
67       iy[m] = _iy[m];
68    }
69    
70    for (j=0;j<N;j++)
71       Rpp += p[j]*p[j];
72    //if (Rpp>.01)
73    //   alpha = (1-sqrt(1-Rpp))/Rpp;
74    for (j=0;j<N;j++)
75       Rxp += x[j]*p[j];
76    for (m=0;m<L;m++)
77       for (i=0;i<N;i++)
78          y[m][i] = 0;
79       
80    for (m=0;m<L;m++)
81       for (i=0;i<N;i++)
82          ny[m][i] = 0;
83
84    for (m=0;m<L;m++)
85       for (i=0;i<N;i++)
86          iy[m][i] = iny[m][i] = 0;
87
88    for (m=0;m<L;m++)
89       xy[m] = yy[m] = yp[m] = gain[m] = 0;
90    
91    pulsesLeft = K;
92    while (pulsesLeft > 0)
93    {
94       int pulsesAtOnce=1;
95       int Lupdate = L;
96       int L2 = L;
97       pulsesAtOnce = pulsesLeft/N;
98       if (pulsesAtOnce<1)
99          pulsesAtOnce = 1;
100       if (pulsesLeft-pulsesAtOnce > 3 || N > 30)
101          Lupdate = 1;
102       //printf ("%d %d %d/%d %d\n", Lupdate, pulsesAtOnce, pulsesLeft, K, N);
103       L2 = Lupdate;
104       if (L2>maxL)
105       {
106          L2 = maxL;
107          maxL *= N;
108       }
109
110       for (m=0;m<L;m++)
111          best_scores[m] = -1e10;
112
113       for (m=0;m<L2;m++)
114       {
115          for (j=0;j<N;j++)
116          {
117             int sign;
118             for (sign=-1;sign<=1;sign+=2)
119             {
120                /* All pulses at one location must have the same size */
121                if (iy[m][j]*sign < 0)
122                   continue;
123                //fprintf (stderr, "%d/%d %d/%d %d/%d\n", i, K, m, L2, j, N);
124                float tmp_xy, tmp_yy, tmp_yp;
125                float score;
126                float g;
127                float s = sign*pulsesAtOnce;
128                
129                /* Updating the sums the the new pulse(s) */
130                tmp_xy = xy[m] + s*x[j]               - alpha*s*p[j]*Rxp;
131                tmp_yy = yy[m] + 2*s*y[m][j] + s*s      +s*s*alpha*alpha*p[j]*p[j]*Rpp - 2*alpha*s*p[j]*yp[m] - 2*s*s*alpha*p[j]*p[j];
132                tmp_yp = yp[m] + s*p[j]               *(1-alpha*Rpp);
133                g = (sqrt(tmp_yp*tmp_yp + tmp_yy - tmp_yy*Rpp) - tmp_yp)/tmp_yy;
134                score = 2*g*tmp_xy - g*g*tmp_yy;
135
136                if (score>best_scores[Lupdate-1])
137                {
138                   int k, n;
139                   int id = Lupdate-1;
140                   float *tmp_ny;
141                   int *tmp_iny;
142                   
143                   tmp_ny = ny[Lupdate-1];
144                   tmp_iny = iny[Lupdate-1];
145                   while (id > 0 && score > best_scores[id-1])
146                      id--;
147                
148                   for (k=Lupdate-1;k>id;k--)
149                   {
150                      nxy[k] = nxy[k-1];
151                      nyy[k] = nyy[k-1];
152                      nyp[k] = nyp[k-1];
153                      //fprintf(stderr, "%d %d \n", N, k);
154                      ny[k] = ny[k-1];
155                      iny[k] = iny[k-1];
156                      gain[k] = gain[k-1];
157                      best_scores[k] = best_scores[k-1];
158                   }
159
160                   ny[id] = tmp_ny;
161                   iny[id] = tmp_iny;
162
163                   nxy[id] = tmp_xy;
164                   nyy[id] = tmp_yy;
165                   nyp[id] = tmp_yp;
166                   gain[id] = g;
167                   for (n=0;n<N;n++)
168                      ny[id][n] = y[m][n] - alpha*s*p[j]*p[n];
169                   ny[id][j] += s;
170
171                   for (n=0;n<N;n++)
172                      iny[id][n] = iy[m][n];
173                   if (s>0)
174                      iny[id][j] += pulsesAtOnce;
175                   else
176                      iny[id][j] -= pulsesAtOnce;
177                   best_scores[id] = score;
178                }
179             }
180          }
181
182       }
183       int k;
184       for (k=0;k<Lupdate;k++)
185       {
186          float *tmp_ny;
187          int *tmp_iny;
188
189          xy[k] = nxy[k];
190          yy[k] = nyy[k];
191          yp[k] = nyp[k];
192          
193          tmp_ny = ny[k];
194          ny[k] = y[k];
195          y[k] = tmp_ny;
196          tmp_iny = iny[k];
197          iny[k] = iy[k];
198          iy[k] = tmp_iny;
199       }
200       pulsesLeft -= pulsesAtOnce;
201    }
202    
203    for (i=0;i<N;i++)
204       x[i] = p[i]+gain[0]*y[0][i];
205    if (0) {
206       float E=1e-15;
207       int ABS = 0;
208       for (i=0;i<N;i++)
209          ABS += abs(iy[0][i]);
210       //if (K != ABS)
211       //   printf ("%d %d\n", K, ABS);
212       for (i=0;i<N;i++)
213          E += x[i]*x[i];
214       //printf ("%f\n", E);
215       E = 1/sqrt(E);
216       for (i=0;i<N;i++)
217          x[i] *= E;
218    }
219    int comb[K];
220    int signs[K];
221    //for (i=0;i<N;i++)
222    //   printf ("%d ", iy[0][i]);
223    pulse2comb(N, K, comb, signs, iy[0]); 
224    ec_enc_uint64(enc,icwrs64(N, K, comb, signs),ncwrs64(N, K));
225    //printf ("%llu ", icwrs64(N, K, comb, signs));
226    /* Recompute the gain in one pass to reduce the encoder-decoder mismatch
227       due to the recursive computation used in quantisation.
228       Not quite sure whether we need that or not */
229    if (1) {
230       float Ryp=0;
231       float Rpp=0;
232       float Ryy=0;
233       float g=0;
234       for (i=0;i<N;i++)
235          Rpp += p[i]*p[i];
236       
237       for (i=0;i<N;i++)
238          Ryp += iy[0][i]*p[i];
239       
240       for (i=0;i<N;i++)
241          y[0][i] = iy[0][i] - alpha*Ryp*p[i];
242       
243       Ryp = 0;
244       for (i=0;i<N;i++)
245          Ryp += y[0][i]*p[i];
246       
247       for (i=0;i<N;i++)
248          Ryy += y[0][i]*y[0][i];
249       
250       g = (sqrt(Ryp*Ryp + Ryy - Ryy*Rpp) - Ryp)/Ryy;
251         
252       for (i=0;i<N;i++)
253          x[i] = p[i] + g*y[0][i];
254       
255    }
256
257 }
258
259 void alg_unquant(float *x, int N, int K, float *p, float alpha, ec_dec *dec)
260 {
261    int i;
262    celt_uint64_t id;
263    int comb[K];
264    int signs[K];
265    int iy[N];
266    float y[N];
267    float Rpp=0, Ryp=0, Ryy=0;
268    float g;
269
270    id = ec_dec_uint64(dec, ncwrs64(N, K));
271    //printf ("%llu ", id);
272    cwrsi64(N, K, id, comb, signs);
273    comb2pulse(N, K, iy, comb, signs);
274    //for (i=0;i<N;i++)
275    //   printf ("%d ", iy[i]);
276    for (i=0;i<N;i++)
277       Rpp += p[i]*p[i];
278
279    for (i=0;i<N;i++)
280       Ryp += iy[i]*p[i];
281
282    for (i=0;i<N;i++)
283       y[i] = iy[i] - alpha*Ryp*p[i];
284
285    /* Recompute after the projection (I think it's right) */
286    Ryp = 0;
287    for (i=0;i<N;i++)
288       Ryp += y[i]*p[i];
289
290    for (i=0;i<N;i++)
291       Ryy += y[i]*y[i];
292
293    g = (sqrt(Ryp*Ryp + Ryy - Ryy*Rpp) - Ryp)/Ryy;
294
295    for (i=0;i<N;i++)
296       x[i] = p[i] + g*y[i];
297 }
298
299
300 static const float pg[11] = {1.f, .75f, .65f, 0.6f, 0.6f, .6f, .55f, .55f, .5f, .5f, .5f};
301
302 void intra_prediction(float *x, float *W, int N, int K, float *Y, float *P, int B, int N0, ec_enc *enc)
303 {
304    int i,j;
305    int best=0;
306    float best_score=0;
307    float s = 1;
308    int sign;
309    float E;
310    int max_pos = N0-N/B;
311    if (max_pos > 32)
312       max_pos = 32;
313
314    for (i=0;i<max_pos*B;i+=B)
315    {
316       int j;
317       float xy=0, yy=0;
318       float score;
319       for (j=0;j<N;j++)
320       {
321          xy += x[j]*Y[i+j];
322          yy += Y[i+j]*Y[i+j];
323       }
324       score = xy*xy/(.001+yy);
325       if (score > best_score)
326       {
327          best_score = score;
328          best = i;
329          if (xy>0)
330             s = 1;
331          else
332             s = -1;
333       }
334    }
335    if (s<0)
336       sign = 1;
337    else
338       sign = 0;
339    //printf ("%d %d ", sign, best);
340    ec_enc_uint(enc,sign,2);
341    ec_enc_uint(enc,best/B,max_pos);
342    //printf ("%d %f\n", best, best_score);
343    
344    float pred_gain;
345    if (K>10)
346       pred_gain = pg[10];
347    else
348       pred_gain = pg[K];
349    E = 1e-10;
350    for (j=0;j<N;j++)
351    {
352       P[j] = s*Y[best+j];
353       E += P[j]*P[j];
354    }
355    E = pred_gain/sqrt(E);
356    for (j=0;j<N;j++)
357       P[j] *= E;
358    if (K>0)
359    {
360       for (j=0;j<N;j++)
361          x[j] -= P[j];
362    } else {
363       for (j=0;j<N;j++)
364          x[j] = P[j];
365    }
366    //printf ("quant ");
367    //for (j=0;j<N;j++) printf ("%f ", P[j]);
368
369 }
370
371 void intra_unquant(float *x, int N, int K, float *Y, float *P, int B, int N0, ec_dec *dec)
372 {
373    int j;
374    int sign;
375    float s;
376    int best;
377    float E;
378    int max_pos = N0-N/B;
379    if (max_pos > 32)
380       max_pos = 32;
381    
382    sign = ec_dec_uint(dec, 2);
383    if (sign == 0)
384       s = 1;
385    else
386       s = -1;
387    
388    best = B*ec_dec_uint(dec, max_pos);
389    //printf ("%d %d ", sign, best);
390
391    float pred_gain;
392    if (K>10)
393       pred_gain = pg[10];
394    else
395       pred_gain = pg[K];
396    E = 1e-10;
397    for (j=0;j<N;j++)
398    {
399       P[j] = s*Y[best+j];
400       E += P[j]*P[j];
401    }
402    E = pred_gain/sqrt(E);
403    for (j=0;j<N;j++)
404       P[j] *= E;
405    if (K==0)
406    {
407       for (j=0;j<N;j++)
408          x[j] = P[j];
409    }
410 }
411
412 void intra_fold(float *x, int N, int K, float *Y, float *P, int B, int N0)
413 {
414    int j;
415    float E;
416    
417    E = 1e-10;
418    for (j=0;j<N;j++)
419    {
420       P[j] = Y[j];
421       E += P[j]*P[j];
422    }
423    E = 1.f/sqrt(E);
424    for (j=0;j<N;j++)
425       P[j] *= E;
426    for (j=0;j<N;j++)
427       x[j] = P[j];
428 }
429