Allowing the quantiser serch to put more than one pulse at one,
[opus.git] / libcelt / vq.c
1 /* (C) 2007 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #include <math.h>
33 #include <stdlib.h>
34 #include "cwrs.h"
35 #include "vq.h"
36
37
38 /* Improved algebraic pulse-base quantiser. The signal x is replaced by the sum of the pitch 
39    a combination of pulses such that its norm is still equal to 1. The only difference with 
40    the quantiser above is that the search is more complete. */
41 void alg_quant(float *x, float *W, int N, int K, float *p, float alpha, ec_enc *enc)
42 {
43    int L = 3;
44    //float tata[200];
45    float y[L][N];
46    int iy[L][N];
47    //float tata2[200];
48    float ny[L][N];
49    int iny[L][N];
50    int i, j, m;
51    int pulsesLeft;
52    float xy[L], nxy[L];
53    float yy[L], nyy[L];
54    float yp[L], nyp[L];
55    float best_scores[L];
56    float Rpp=0, Rxp=0;
57    float gain[L];
58    int maxL = 1;
59    
60    for (j=0;j<N;j++)
61       Rpp += p[j]*p[j];
62    //if (Rpp>.01)
63    //   alpha = (1-sqrt(1-Rpp))/Rpp;
64    for (j=0;j<N;j++)
65       Rxp += x[j]*p[j];
66    for (m=0;m<L;m++)
67       for (i=0;i<N;i++)
68          y[m][i] = 0;
69       
70    for (m=0;m<L;m++)
71       for (i=0;i<N;i++)
72          ny[m][i] = 0;
73
74    for (m=0;m<L;m++)
75       for (i=0;i<N;i++)
76          iy[m][i] = iny[m][i] = 0;
77
78    for (m=0;m<L;m++)
79       xy[m] = yy[m] = yp[m] = gain[m] = 0;
80    
81    pulsesLeft = K;
82    while (pulsesLeft > 0)
83    {
84       int pulsesAtOnce=1;
85       int L2 = L;
86       if (L>maxL)
87       {
88          L2 = maxL;
89          maxL *= N;
90       }
91       if (pulsesLeft > 5)
92          L2 = 1;
93       
94       pulsesAtOnce = pulsesLeft/N;
95       if (pulsesAtOnce<1)
96          pulsesAtOnce = 1;
97
98       for (m=0;m<L;m++)
99          best_scores[m] = -1e10;
100
101       for (m=0;m<L2;m++)
102       {
103          for (j=0;j<N;j++)
104          {
105             int sign;
106             for (sign=-1;sign<=1;sign+=2)
107             {
108                if (iy[m][j]*sign < 0)
109                   continue;
110                //fprintf (stderr, "%d/%d %d/%d %d/%d\n", i, K, m, L2, j, N);
111                float tmp_xy, tmp_yy, tmp_yp;
112                float score;
113                float g;
114                float s = sign*pulsesAtOnce;
115                tmp_xy = xy[m] + s*x[j]               - alpha*s*p[j]*Rxp;
116                tmp_yy = yy[m] + 2*s*y[m][j] + s*s      +s*s*alpha*alpha*p[j]*p[j]*Rpp - 2*alpha*s*p[j]*yp[m] - 2*s*s*alpha*p[j]*p[j];
117                tmp_yp = yp[m] + s*p[j]               *(1-alpha*Rpp);
118                g = (sqrt(tmp_yp*tmp_yp + tmp_yy - tmp_yy*Rpp) - tmp_yp)/tmp_yy;
119                score = 2*g*tmp_xy - g*g*tmp_yy;
120
121                if (score>best_scores[L-1])
122                {
123                   int k, n;
124                   int id = L-1;
125                   while (id > 0 && score > best_scores[id-1])
126                      id--;
127                
128                   for (k=L-1;k>id;k--)
129                   {
130                      nxy[k] = nxy[k-1];
131                      nyy[k] = nyy[k-1];
132                      nyp[k] = nyp[k-1];
133                      //fprintf(stderr, "%d %d \n", N, k);
134                      for (n=0;n<N;n++)
135                         ny[k][n] = ny[k-1][n];
136                      for (n=0;n<N;n++)
137                         iny[k][n] = iny[k-1][n];
138                      gain[k] = gain[k-1];
139                      best_scores[k] = best_scores[k-1];
140                   }
141
142                   nxy[id] = tmp_xy;
143                   nyy[id] = tmp_yy;
144                   nyp[id] = tmp_yp;
145                   gain[id] = g;
146                   for (n=0;n<N;n++)
147                      ny[id][n] = y[m][n];
148                   ny[id][j] += s;
149                   for (n=0;n<N;n++)
150                      ny[id][n] -= alpha*s*p[j]*p[n];
151                
152                   for (n=0;n<N;n++)
153                      iny[id][n] = iy[m][n];
154                   if (s>0)
155                      iny[id][j] += pulsesAtOnce;
156                   else
157                      iny[id][j] -= pulsesAtOnce;
158                   best_scores[id] = score;
159                }
160             }   
161          }
162          
163       }
164       int k,n;
165       /* FIXME: We could be swapping pointers instead */
166       for (k=0;k<L;k++)
167       {
168          xy[k] = nxy[k];
169          yy[k] = nyy[k];
170          yp[k] = nyp[k];
171          for (n=0;n<N;n++)
172             y[k][n] = ny[k][n];
173          for (n=0;n<N;n++)
174             iy[k][n] = iny[k][n];
175       }
176       pulsesLeft -= pulsesAtOnce;
177    }
178    
179    for (i=0;i<N;i++)
180       x[i] = p[i]+gain[0]*y[0][i];
181    if (0) {
182       float E=1e-15;
183       int ABS = 0;
184       for (i=0;i<N;i++)
185          ABS += abs(iy[0][i]);
186       //if (K != ABS)
187       //   printf ("%d %d\n", K, ABS);
188       for (i=0;i<N;i++)
189          E += x[i]*x[i];
190       //printf ("%f\n", E);
191       E = 1/sqrt(E);
192       for (i=0;i<N;i++)
193          x[i] *= E;
194    }
195    int comb[K];
196    int signs[K];
197    //for (i=0;i<N;i++)
198    //   printf ("%d ", iy[0][i]);
199    pulse2comb(N, K, comb, signs, iy[0]); 
200    ec_enc_uint64(enc,icwrs64(N, K, comb, signs),ncwrs64(N, K));
201    //printf ("%llu ", icwrs64(N, K, comb, signs));
202    /* Recompute the gain in one pass to reduce the encoder-decoder mismatch
203       due to the recursive computation used in quantisation */
204    if (1) {
205       float Ryp=0;
206       float Rpp=0;
207       float Ryy=0;
208       float g=0;
209       for (i=0;i<N;i++)
210          Rpp += p[i]*p[i];
211       
212       for (i=0;i<N;i++)
213          Ryp += iy[0][i]*p[i];
214       
215       for (i=0;i<N;i++)
216          y[0][i] = iy[0][i] - alpha*Ryp*p[i];
217       
218       Ryp = 0;
219       for (i=0;i<N;i++)
220          Ryp += y[0][i]*p[i];
221       
222       for (i=0;i<N;i++)
223          Ryy += y[0][i]*y[0][i];
224       
225       g = (sqrt(Ryp*Ryp + Ryy - Ryy*Rpp) - Ryp)/Ryy;
226         
227       for (i=0;i<N;i++)
228          x[i] = p[i] + g*y[0][i];
229       
230    }
231
232 }
233
234 void alg_unquant(float *x, int N, int K, float *p, float alpha, ec_dec *dec)
235 {
236    int i;
237    celt_uint64_t id;
238    int comb[K];
239    int signs[K];
240    int iy[N];
241    float y[N];
242    float Rpp=0, Ryp=0, Ryy=0;
243    float g;
244
245    id = ec_dec_uint64(dec, ncwrs64(N, K));
246    //printf ("%llu ", id);
247    cwrsi64(N, K, id, comb, signs);
248    comb2pulse(N, K, iy, comb, signs);
249    //for (i=0;i<N;i++)
250    //   printf ("%d ", iy[i]);
251    for (i=0;i<N;i++)
252       Rpp += p[i]*p[i];
253
254    for (i=0;i<N;i++)
255       Ryp += iy[i]*p[i];
256
257    for (i=0;i<N;i++)
258       y[i] = iy[i] - alpha*Ryp*p[i];
259
260    /* Recompute after the projection (I think it's right) */
261    Ryp = 0;
262    for (i=0;i<N;i++)
263       Ryp += y[i]*p[i];
264
265    for (i=0;i<N;i++)
266       Ryy += y[i]*y[i];
267
268    g = (sqrt(Ryp*Ryp + Ryy - Ryy*Rpp) - Ryp)/Ryy;
269
270    for (i=0;i<N;i++)
271       x[i] = p[i] + g*y[i];
272 }
273
274
275 static const float pg[11] = {1.f, .75f, .65f, 0.6f, 0.6f, .6f, .55f, .55f, .5f, .5f, .5f};
276
277 void intra_prediction(float *x, float *W, int N, int K, float *Y, float *P, int B, int N0, ec_enc *enc)
278 {
279    int i,j;
280    int best=0;
281    float best_score=0;
282    float s = 1;
283    int sign;
284    float E;
285    int max_pos = N0-N/B;
286    if (max_pos > 32)
287       max_pos = 32;
288
289    for (i=0;i<max_pos*B;i+=B)
290    {
291       int j;
292       float xy=0, yy=0;
293       float score;
294       for (j=0;j<N;j++)
295       {
296          xy += x[j]*Y[i+j];
297          yy += Y[i+j]*Y[i+j];
298       }
299       score = xy*xy/(.001+yy);
300       if (score > best_score)
301       {
302          best_score = score;
303          best = i;
304          if (xy>0)
305             s = 1;
306          else
307             s = -1;
308       }
309    }
310    if (s<0)
311       sign = 1;
312    else
313       sign = 0;
314    //printf ("%d %d ", sign, best);
315    ec_enc_uint(enc,sign,2);
316    ec_enc_uint(enc,best/B,max_pos);
317    //printf ("%d %f\n", best, best_score);
318    
319    float pred_gain;
320    if (K>10)
321       pred_gain = pg[10];
322    else
323       pred_gain = pg[K];
324    E = 1e-10;
325    for (j=0;j<N;j++)
326    {
327       P[j] = s*Y[best+j];
328       E += P[j]*P[j];
329    }
330    E = pred_gain/sqrt(E);
331    for (j=0;j<N;j++)
332       P[j] *= E;
333    if (K>0)
334    {
335       for (j=0;j<N;j++)
336          x[j] -= P[j];
337    } else {
338       for (j=0;j<N;j++)
339          x[j] = P[j];
340    }
341    //printf ("quant ");
342    //for (j=0;j<N;j++) printf ("%f ", P[j]);
343
344 }
345
346 void intra_unquant(float *x, int N, int K, float *Y, float *P, int B, int N0, ec_dec *dec)
347 {
348    int j;
349    int sign;
350    float s;
351    int best;
352    float E;
353    int max_pos = N0-N/B;
354    if (max_pos > 32)
355       max_pos = 32;
356    
357    sign = ec_dec_uint(dec, 2);
358    if (sign == 0)
359       s = 1;
360    else
361       s = -1;
362    
363    best = B*ec_dec_uint(dec, max_pos);
364    //printf ("%d %d ", sign, best);
365
366    float pred_gain;
367    if (K>10)
368       pred_gain = pg[10];
369    else
370       pred_gain = pg[K];
371    E = 1e-10;
372    for (j=0;j<N;j++)
373    {
374       P[j] = s*Y[best+j];
375       E += P[j]*P[j];
376    }
377    E = pred_gain/sqrt(E);
378    for (j=0;j<N;j++)
379       P[j] *= E;
380    if (K==0)
381    {
382       for (j=0;j<N;j++)
383          x[j] = P[j];
384    }
385 }
386
387 void intra_fold(float *x, int N, int K, float *Y, float *P, int B, int N0)
388 {
389    int j;
390    float E;
391    
392    E = 1e-10;
393    for (j=0;j<N;j++)
394    {
395       P[j] = Y[j];
396       E += P[j]*P[j];
397    }
398    E = 1.f/sqrt(E);
399    for (j=0;j<N;j++)
400       P[j] *= E;
401    for (j=0;j<N;j++)
402       x[j] = P[j];
403 }
404