Propagating perceptual weighting around (not used yet).
[opus.git] / libcelt / vq.c
1 /* (C) 2007 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #include <math.h>
33 #include <stdlib.h>
34 #include "cwrs.h"
35 #include "vq.h"
36
37
38 /* Improved algebraic pulse-base quantiser. The signal x is replaced by the sum of the pitch 
39    a combination of pulses such that its norm is still equal to 1. The only difference with 
40    the quantiser above is that the search is more complete. */
41 void alg_quant(float *x, float *W, int N, int K, float *p, float alpha, ec_enc *enc)
42 {
43    int L = 3;
44    //float tata[200];
45    float y[L][N];
46    int iy[L][N];
47    //float tata2[200];
48    float ny[L][N];
49    int iny[L][N];
50    int i, j, m;
51    float xy[L], nxy[L];
52    float yy[L], nyy[L];
53    float yp[L], nyp[L];
54    float best_scores[L];
55    float Rpp=0, Rxp=0;
56    float gain[L];
57    int maxL = 1;
58    
59    for (j=0;j<N;j++)
60       Rpp += p[j]*p[j];
61    //if (Rpp>.01)
62    //   alpha = (1-sqrt(1-Rpp))/Rpp;
63    for (j=0;j<N;j++)
64       Rxp += x[j]*p[j];
65    for (m=0;m<L;m++)
66       for (i=0;i<N;i++)
67          y[m][i] = 0;
68       
69    for (m=0;m<L;m++)
70       for (i=0;i<N;i++)
71          ny[m][i] = 0;
72
73    for (m=0;m<L;m++)
74       for (i=0;i<N;i++)
75          iy[m][i] = iny[m][i] = 0;
76
77    for (m=0;m<L;m++)
78       xy[m] = yy[m] = yp[m] = gain[m] = 0;
79    
80    for (i=0;i<K;i++)
81    {
82       int L2 = L;
83       if (L>maxL)
84       {
85          L2 = maxL;
86          maxL *= N;
87       }
88       for (m=0;m<L;m++)
89          best_scores[m] = -1e10;
90
91       for (m=0;m<L2;m++)
92       {
93          for (j=0;j<N;j++)
94          {
95             int sign;
96             for (sign=-1;sign<=1;sign+=2)
97             {
98                if (iy[m][j]*sign < 0)
99                   continue;
100                //fprintf (stderr, "%d/%d %d/%d %d/%d\n", i, K, m, L2, j, N);
101                float tmp_xy, tmp_yy, tmp_yp;
102                float score;
103                float g;
104                float s = sign;
105                tmp_xy = xy[m] + s*x[j]               - alpha*s*p[j]*Rxp;
106                tmp_yy = yy[m] + 2*s*y[m][j] + 1      +alpha*alpha*p[j]*p[j]*Rpp - 2*alpha*s*p[j]*yp[m] - 2*alpha*p[j]*p[j];
107                tmp_yp = yp[m] + s*p[j]               *(1-alpha*Rpp);
108                g = (sqrt(tmp_yp*tmp_yp + tmp_yy - tmp_yy*Rpp) - tmp_yp)/tmp_yy;
109                score = 2*g*tmp_xy - g*g*tmp_yy;
110
111                if (score>best_scores[L-1])
112                {
113                   int k, n;
114                   int id = L-1;
115                   while (id > 0 && score > best_scores[id-1])
116                      id--;
117                
118                   for (k=L-1;k>id;k--)
119                   {
120                      nxy[k] = nxy[k-1];
121                      nyy[k] = nyy[k-1];
122                      nyp[k] = nyp[k-1];
123                      //fprintf(stderr, "%d %d \n", N, k);
124                      for (n=0;n<N;n++)
125                         ny[k][n] = ny[k-1][n];
126                      for (n=0;n<N;n++)
127                         iny[k][n] = iny[k-1][n];
128                      gain[k] = gain[k-1];
129                      best_scores[k] = best_scores[k-1];
130                   }
131
132                   nxy[id] = tmp_xy;
133                   nyy[id] = tmp_yy;
134                   nyp[id] = tmp_yp;
135                   gain[id] = g;
136                   for (n=0;n<N;n++)
137                      ny[id][n] = y[m][n];
138                   ny[id][j] += s;
139                   for (n=0;n<N;n++)
140                      ny[id][n] -= alpha*s*p[j]*p[n];
141                
142                   for (n=0;n<N;n++)
143                      iny[id][n] = iy[m][n];
144                   if (s>0)
145                      iny[id][j] += 1;
146                   else
147                      iny[id][j] -= 1;
148                   best_scores[id] = score;
149                }
150             }   
151          }
152          
153       }
154       int k,n;
155       for (k=0;k<L;k++)
156       {
157          xy[k] = nxy[k];
158          yy[k] = nyy[k];
159          yp[k] = nyp[k];
160          for (n=0;n<N;n++)
161             y[k][n] = ny[k][n];
162          for (n=0;n<N;n++)
163             iy[k][n] = iny[k][n];
164       }
165
166    }
167    
168    for (i=0;i<N;i++)
169       x[i] = p[i]+gain[0]*y[0][i];
170    if (0) {
171       float E=1e-15;
172       int ABS = 0;
173       for (i=0;i<N;i++)
174          ABS += abs(iy[0][i]);
175       //if (K != ABS)
176       //   printf ("%d %d\n", K, ABS);
177       for (i=0;i<N;i++)
178          E += x[i]*x[i];
179       //printf ("%f\n", E);
180       E = 1/sqrt(E);
181       for (i=0;i<N;i++)
182          x[i] *= E;
183    }
184    int comb[K];
185    int signs[K];
186    //for (i=0;i<N;i++)
187    //   printf ("%d ", iy[0][i]);
188    pulse2comb(N, K, comb, signs, iy[0]); 
189    ec_enc_uint64(enc,icwrs64(N, K, comb, signs),ncwrs64(N, K));
190    
191    /* Recompute the gain in one pass (to reduce errors) */
192    if (0) {
193       float Ryp=0;
194       float Rpp=0;
195       float Ryy=0;
196       float g=0;
197       for (i=0;i<N;i++)
198          Rpp += p[i]*p[i];
199       
200       for (i=0;i<N;i++)
201          Ryp += iy[0][i]*p[i];
202       
203       for (i=0;i<N;i++)
204          y[0][i] = iy[0][i] - alpha*Ryp*p[i];
205       
206       /* Recompute after the projection (I think it's right) */
207       Ryp = 0;
208       for (i=0;i<N;i++)
209          Ryp += y[0][i]*p[i];
210       
211       for (i=0;i<N;i++)
212          Ryy += y[0][i]*y[0][i];
213       
214       g = (sqrt(Ryp*Ryp + Ryy - Ryy*Rpp) - Ryp)/Ryy;
215         
216       for (i=0;i<N;i++)
217          x[i] = p[i] + g*y[0][i];
218       
219    }
220
221 }
222
223 static const float pg[5] = {1.f, .6f, .45f, 0.35f, 0.25f};
224
225 /* Finds the right offset into Y and copy it */
226 void copy_quant(float *x, float *W, int N, int K, float *Y, int B, int N0, ec_enc *enc)
227 {
228    int i,j;
229    int best=0;
230    float best_score=0;
231    float s = 1;
232    int sign;
233    float E;
234    for (i=0;i<N0*B-N;i+=B)
235    {
236       int j;
237       float xy=0, yy=0;
238       float score;
239       for (j=0;j<N;j++)
240       {
241          xy += x[j]*Y[i+j];
242          yy += Y[i+j]*Y[i+j];
243       }
244       score = xy*xy/(.001+yy);
245       if (score > best_score)
246       {
247          best_score = score;
248          best = i;
249          if (xy>0)
250             s = 1;
251          else
252             s = -1;
253       }
254    }
255    if (s<0)
256       sign = 1;
257    else
258       sign = 0;
259    //printf ("%d %d ", sign, best);
260    ec_enc_uint(enc,sign,2);
261    ec_enc_uint(enc,best/B,N0-N/B);
262    //printf ("%d %f\n", best, best_score);
263    if (K==0)
264    {
265       E = 1e-10;
266       for (j=0;j<N;j++)
267       {
268          x[j] = s*Y[best+j];
269          E += x[j]*x[j];
270       }
271       E = 1/sqrt(E);
272       for (j=0;j<N;j++)
273          x[j] *= E;
274    } else {
275       float P[N];
276       float pred_gain;
277       if (K>4)
278          pred_gain = .5;
279       else
280          pred_gain = pg[K];
281       E = 1e-10;
282       for (j=0;j<N;j++)
283       {
284          P[j] = s*Y[best+j];
285          E += P[j]*P[j];
286       }
287       E = .8/sqrt(E);
288       for (j=0;j<N;j++)
289          P[j] *= E;
290       alg_quant(x, W, N, K, P, 0, enc);
291    }
292 }
293
294 void alg_unquant(float *x, int N, int K, float *p, float alpha, ec_dec *dec)
295 {
296    int i;
297    celt_uint64_t id;
298    int comb[K];
299    int signs[K];
300    int iy[N];
301    float y[N];
302    float Rpp=0, Ryp=0, Ryy=0;
303    float g;
304    
305    id = ec_dec_uint64(dec, ncwrs64(N, K));
306    cwrsi64(N, K, id, comb, signs);
307    comb2pulse(N, K, iy, comb, signs);
308    //for (i=0;i<N;i++)
309    //   printf ("%d ", iy[i]);
310    for (i=0;i<N;i++)
311       Rpp += p[i]*p[i];
312
313    for (i=0;i<N;i++)
314       Ryp += iy[i]*p[i];
315
316    for (i=0;i<N;i++)
317       y[i] = iy[i] - alpha*Ryp*p[i];
318    
319    /* Recompute after the projection (I think it's right) */
320    Ryp = 0;
321    for (i=0;i<N;i++)
322       Ryp += y[i]*p[i];
323    
324    for (i=0;i<N;i++)
325       Ryy += y[i]*y[i];
326
327    g = (sqrt(Ryp*Ryp + Ryy - Ryy*Rpp) - Ryp)/Ryy;
328    
329    for (i=0;i<N;i++)
330       x[i] = p[i] + g*y[i];
331 }
332
333 void copy_unquant(float *x, int N, int K, float *Y, int B, int N0, ec_dec *dec)
334 {
335    int j;
336    int sign;
337    float s;
338    int best;
339    float E;
340    sign = ec_dec_uint(dec, 2);
341    if (sign == 0)
342       s = 1;
343    else
344       s = -1;
345    
346    best = B*ec_dec_uint(dec, N0-N/B);
347    //printf ("%d %d ", sign, best);
348
349    if (K==0)
350    {
351       E = 1e-10;
352       for (j=0;j<N;j++)
353       {
354          x[j] = s*Y[best+j];
355          E += x[j]*x[j];
356       }
357       E = 1/sqrt(E);
358       for (j=0;j<N;j++)
359          x[j] *= E;
360    } else {
361       float P[N];
362       float pred_gain;
363       if (K>4)
364          pred_gain = .5;
365       else
366          pred_gain = pg[K];
367       E = 1e-10;
368       for (j=0;j<N;j++)
369       {
370          P[j] = s*Y[best+j];
371          E += P[j]*P[j];
372       }
373       E = .8/sqrt(E);
374       for (j=0;j<N;j++)
375          P[j] *= E;
376       alg_unquant(x, N, K, P, 0, dec);
377    }
378 }