Copying pointers is faster than copying arrays (who knew!).
[opus.git] / libcelt / vq.c
1 /* (C) 2007 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #include <math.h>
33 #include <stdlib.h>
34 #include "cwrs.h"
35 #include "vq.h"
36
37
38 /* Improved algebraic pulse-base quantiser. The signal x is replaced by the sum of the pitch 
39    a combination of pulses such that its norm is still equal to 1. The only difference with 
40    the quantiser above is that the search is more complete. */
41 void alg_quant(float *x, float *W, int N, int K, float *p, float alpha, ec_enc *enc)
42 {
43    int L = 3;
44    //float tata[200];
45    float _y[L][N];
46    int _iy[L][N];
47    //float tata2[200];
48    float _ny[L][N];
49    int _iny[L][N];
50    float *(ny[L]), *(y[L]);
51    int *(iny[L]), *(iy[L]);
52    int i, j, m;
53    int pulsesLeft;
54    float xy[L], nxy[L];
55    float yy[L], nyy[L];
56    float yp[L], nyp[L];
57    float best_scores[L];
58    float Rpp=0, Rxp=0;
59    float gain[L];
60    int maxL = 1;
61    
62    for (m=0;m<L;m++)
63    {
64       ny[m] = _ny[m];
65       iny[m] = _iny[m];
66       y[m] = _y[m];
67       iy[m] = _iy[m];
68    }
69    
70    for (j=0;j<N;j++)
71       Rpp += p[j]*p[j];
72    //if (Rpp>.01)
73    //   alpha = (1-sqrt(1-Rpp))/Rpp;
74    for (j=0;j<N;j++)
75       Rxp += x[j]*p[j];
76    for (m=0;m<L;m++)
77       for (i=0;i<N;i++)
78          y[m][i] = 0;
79       
80    for (m=0;m<L;m++)
81       for (i=0;i<N;i++)
82          ny[m][i] = 0;
83
84    for (m=0;m<L;m++)
85       for (i=0;i<N;i++)
86          iy[m][i] = iny[m][i] = 0;
87
88    for (m=0;m<L;m++)
89       xy[m] = yy[m] = yp[m] = gain[m] = 0;
90    
91    pulsesLeft = K;
92    while (pulsesLeft > 0)
93    {
94       int pulsesAtOnce=1;
95       int L2 = L;
96       if (L>maxL)
97       {
98          L2 = maxL;
99          maxL *= N;
100       }
101       if (pulsesLeft > 5)
102          L2 = 1;
103       
104       pulsesAtOnce = pulsesLeft/N;
105       if (pulsesAtOnce<1)
106          pulsesAtOnce = 1;
107
108       for (m=0;m<L;m++)
109          best_scores[m] = -1e10;
110
111       for (m=0;m<L2;m++)
112       {
113          for (j=0;j<N;j++)
114          {
115             int sign;
116             for (sign=-1;sign<=1;sign+=2)
117             {
118                /* All pulses at one location must have the same size */
119                if (iy[m][j]*sign < 0)
120                   continue;
121                //fprintf (stderr, "%d/%d %d/%d %d/%d\n", i, K, m, L2, j, N);
122                float tmp_xy, tmp_yy, tmp_yp;
123                float score;
124                float g;
125                float s = sign*pulsesAtOnce;
126                
127                /* Updating the sums the the new pulse(s) */
128                tmp_xy = xy[m] + s*x[j]               - alpha*s*p[j]*Rxp;
129                tmp_yy = yy[m] + 2*s*y[m][j] + s*s      +s*s*alpha*alpha*p[j]*p[j]*Rpp - 2*alpha*s*p[j]*yp[m] - 2*s*s*alpha*p[j]*p[j];
130                tmp_yp = yp[m] + s*p[j]               *(1-alpha*Rpp);
131                g = (sqrt(tmp_yp*tmp_yp + tmp_yy - tmp_yy*Rpp) - tmp_yp)/tmp_yy;
132                score = 2*g*tmp_xy - g*g*tmp_yy;
133
134                if (score>best_scores[L-1])
135                {
136                   int k, n;
137                   int id = L-1;
138                   float *tmp_ny;
139                   int *tmp_iny;
140                   
141                   tmp_ny = ny[L-1];
142                   tmp_iny = iny[L-1];
143                   while (id > 0 && score > best_scores[id-1])
144                      id--;
145                
146                   for (k=L-1;k>id;k--)
147                   {
148                      nxy[k] = nxy[k-1];
149                      nyy[k] = nyy[k-1];
150                      nyp[k] = nyp[k-1];
151                      //fprintf(stderr, "%d %d \n", N, k);
152                      ny[k] = ny[k-1];
153                      iny[k] = iny[k-1];
154                      gain[k] = gain[k-1];
155                      best_scores[k] = best_scores[k-1];
156                   }
157
158                   ny[id] = tmp_ny;
159                   iny[id] = tmp_iny;
160                   
161                   nxy[id] = tmp_xy;
162                   nyy[id] = tmp_yy;
163                   nyp[id] = tmp_yp;
164                   gain[id] = g;
165                   for (n=0;n<N;n++)
166                      ny[id][n] = y[m][n];
167                   ny[id][j] += s;
168                   for (n=0;n<N;n++)
169                      ny[id][n] -= alpha*s*p[j]*p[n];
170                
171                   for (n=0;n<N;n++)
172                      iny[id][n] = iy[m][n];
173                   if (s>0)
174                      iny[id][j] += pulsesAtOnce;
175                   else
176                      iny[id][j] -= pulsesAtOnce;
177                   best_scores[id] = score;
178                }
179             }   
180          }
181          
182       }
183       int k;
184       for (k=0;k<L;k++)
185       {
186          float *tmp_ny;
187          int *tmp_iny;
188
189          xy[k] = nxy[k];
190          yy[k] = nyy[k];
191          yp[k] = nyp[k];
192          
193          tmp_ny = ny[k];
194          ny[k] = y[k];
195          y[k] = tmp_ny;
196          tmp_iny = iny[k];
197          iny[k] = iy[k];
198          iy[k] = tmp_iny;
199       }
200       pulsesLeft -= pulsesAtOnce;
201    }
202    
203    for (i=0;i<N;i++)
204       x[i] = p[i]+gain[0]*y[0][i];
205    if (0) {
206       float E=1e-15;
207       int ABS = 0;
208       for (i=0;i<N;i++)
209          ABS += abs(iy[0][i]);
210       //if (K != ABS)
211       //   printf ("%d %d\n", K, ABS);
212       for (i=0;i<N;i++)
213          E += x[i]*x[i];
214       //printf ("%f\n", E);
215       E = 1/sqrt(E);
216       for (i=0;i<N;i++)
217          x[i] *= E;
218    }
219    int comb[K];
220    int signs[K];
221    //for (i=0;i<N;i++)
222    //   printf ("%d ", iy[0][i]);
223    pulse2comb(N, K, comb, signs, iy[0]); 
224    ec_enc_uint64(enc,icwrs64(N, K, comb, signs),ncwrs64(N, K));
225    //printf ("%llu ", icwrs64(N, K, comb, signs));
226    /* Recompute the gain in one pass to reduce the encoder-decoder mismatch
227       due to the recursive computation used in quantisation */
228    if (1) {
229       float Ryp=0;
230       float Rpp=0;
231       float Ryy=0;
232       float g=0;
233       for (i=0;i<N;i++)
234          Rpp += p[i]*p[i];
235       
236       for (i=0;i<N;i++)
237          Ryp += iy[0][i]*p[i];
238       
239       for (i=0;i<N;i++)
240          y[0][i] = iy[0][i] - alpha*Ryp*p[i];
241       
242       Ryp = 0;
243       for (i=0;i<N;i++)
244          Ryp += y[0][i]*p[i];
245       
246       for (i=0;i<N;i++)
247          Ryy += y[0][i]*y[0][i];
248       
249       g = (sqrt(Ryp*Ryp + Ryy - Ryy*Rpp) - Ryp)/Ryy;
250         
251       for (i=0;i<N;i++)
252          x[i] = p[i] + g*y[0][i];
253       
254    }
255
256 }
257
258 void alg_unquant(float *x, int N, int K, float *p, float alpha, ec_dec *dec)
259 {
260    int i;
261    celt_uint64_t id;
262    int comb[K];
263    int signs[K];
264    int iy[N];
265    float y[N];
266    float Rpp=0, Ryp=0, Ryy=0;
267    float g;
268
269    id = ec_dec_uint64(dec, ncwrs64(N, K));
270    //printf ("%llu ", id);
271    cwrsi64(N, K, id, comb, signs);
272    comb2pulse(N, K, iy, comb, signs);
273    //for (i=0;i<N;i++)
274    //   printf ("%d ", iy[i]);
275    for (i=0;i<N;i++)
276       Rpp += p[i]*p[i];
277
278    for (i=0;i<N;i++)
279       Ryp += iy[i]*p[i];
280
281    for (i=0;i<N;i++)
282       y[i] = iy[i] - alpha*Ryp*p[i];
283
284    /* Recompute after the projection (I think it's right) */
285    Ryp = 0;
286    for (i=0;i<N;i++)
287       Ryp += y[i]*p[i];
288
289    for (i=0;i<N;i++)
290       Ryy += y[i]*y[i];
291
292    g = (sqrt(Ryp*Ryp + Ryy - Ryy*Rpp) - Ryp)/Ryy;
293
294    for (i=0;i<N;i++)
295       x[i] = p[i] + g*y[i];
296 }
297
298
299 static const float pg[11] = {1.f, .75f, .65f, 0.6f, 0.6f, .6f, .55f, .55f, .5f, .5f, .5f};
300
301 void intra_prediction(float *x, float *W, int N, int K, float *Y, float *P, int B, int N0, ec_enc *enc)
302 {
303    int i,j;
304    int best=0;
305    float best_score=0;
306    float s = 1;
307    int sign;
308    float E;
309    int max_pos = N0-N/B;
310    if (max_pos > 32)
311       max_pos = 32;
312
313    for (i=0;i<max_pos*B;i+=B)
314    {
315       int j;
316       float xy=0, yy=0;
317       float score;
318       for (j=0;j<N;j++)
319       {
320          xy += x[j]*Y[i+j];
321          yy += Y[i+j]*Y[i+j];
322       }
323       score = xy*xy/(.001+yy);
324       if (score > best_score)
325       {
326          best_score = score;
327          best = i;
328          if (xy>0)
329             s = 1;
330          else
331             s = -1;
332       }
333    }
334    if (s<0)
335       sign = 1;
336    else
337       sign = 0;
338    //printf ("%d %d ", sign, best);
339    ec_enc_uint(enc,sign,2);
340    ec_enc_uint(enc,best/B,max_pos);
341    //printf ("%d %f\n", best, best_score);
342    
343    float pred_gain;
344    if (K>10)
345       pred_gain = pg[10];
346    else
347       pred_gain = pg[K];
348    E = 1e-10;
349    for (j=0;j<N;j++)
350    {
351       P[j] = s*Y[best+j];
352       E += P[j]*P[j];
353    }
354    E = pred_gain/sqrt(E);
355    for (j=0;j<N;j++)
356       P[j] *= E;
357    if (K>0)
358    {
359       for (j=0;j<N;j++)
360          x[j] -= P[j];
361    } else {
362       for (j=0;j<N;j++)
363          x[j] = P[j];
364    }
365    //printf ("quant ");
366    //for (j=0;j<N;j++) printf ("%f ", P[j]);
367
368 }
369
370 void intra_unquant(float *x, int N, int K, float *Y, float *P, int B, int N0, ec_dec *dec)
371 {
372    int j;
373    int sign;
374    float s;
375    int best;
376    float E;
377    int max_pos = N0-N/B;
378    if (max_pos > 32)
379       max_pos = 32;
380    
381    sign = ec_dec_uint(dec, 2);
382    if (sign == 0)
383       s = 1;
384    else
385       s = -1;
386    
387    best = B*ec_dec_uint(dec, max_pos);
388    //printf ("%d %d ", sign, best);
389
390    float pred_gain;
391    if (K>10)
392       pred_gain = pg[10];
393    else
394       pred_gain = pg[K];
395    E = 1e-10;
396    for (j=0;j<N;j++)
397    {
398       P[j] = s*Y[best+j];
399       E += P[j]*P[j];
400    }
401    E = pred_gain/sqrt(E);
402    for (j=0;j<N;j++)
403       P[j] *= E;
404    if (K==0)
405    {
406       for (j=0;j<N;j++)
407          x[j] = P[j];
408    }
409 }
410
411 void intra_fold(float *x, int N, int K, float *Y, float *P, int B, int N0)
412 {
413    int j;
414    float E;
415    
416    E = 1e-10;
417    for (j=0;j<N;j++)
418    {
419       P[j] = Y[j];
420       E += P[j]*P[j];
421    }
422    E = 1.f/sqrt(E);
423    for (j=0;j<N;j++)
424       P[j] *= E;
425    for (j=0;j<N;j++)
426       x[j] = P[j];
427 }
428