Nearly working cheating decoder.
[opus.git] / libcelt / vq.c
1 /* (C) 2007 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #include <math.h>
33 #include <stdlib.h>
34 #include "cwrs.h"
35 #include "vq.h"
36
37
38 /* Improved algebraic pulse-base quantiser. The signal x is replaced by the sum of the pitch 
39    a combination of pulses such that its norm is still equal to 1. The only difference with 
40    the quantiser above is that the search is more complete. */
41 void alg_quant(float *x, int N, int K, float *p, ec_enc *enc)
42 {
43    int L = 5;
44    //float tata[200];
45    float y[L][N];
46    int iy[L][N];
47    //float tata2[200];
48    float ny[L][N];
49    int iny[L][N];
50    int i, j, m;
51    float xy[L], nxy[L];
52    float yy[L], nyy[L];
53    float yp[L], nyp[L];
54    float best_scores[L];
55    float Rpp=0, Rxp=0;
56    float gain[L];
57    int maxL = 1;
58    float alpha = .9;
59    
60    for (j=0;j<N;j++)
61       Rpp += p[j]*p[j];
62    //if (Rpp>.01)
63    //   alpha = (1-sqrt(1-Rpp))/Rpp;
64    for (j=0;j<N;j++)
65       Rxp += x[j]*p[j];
66    for (m=0;m<L;m++)
67       for (i=0;i<N;i++)
68          y[m][i] = 0;
69       
70    for (m=0;m<L;m++)
71       for (i=0;i<N;i++)
72          ny[m][i] = 0;
73
74    for (m=0;m<L;m++)
75       for (i=0;i<N;i++)
76          iy[m][i] = iny[m][i] = 0;
77
78    for (m=0;m<L;m++)
79       xy[m] = yy[m] = yp[m] = gain[m] = 0;
80    
81    for (i=0;i<K;i++)
82    {
83       int L2 = L;
84       if (L>maxL)
85       {
86          L2 = maxL;
87          maxL *= N;
88       }
89       for (m=0;m<L;m++)
90          best_scores[m] = -1e10;
91
92       for (m=0;m<L2;m++)
93       {
94          for (j=0;j<N;j++)
95          {
96             int sign;
97             for (sign=-1;sign<=1;sign+=2)
98             {
99                if (iy[m][j]*sign < 0)
100                   continue;
101                //fprintf (stderr, "%d/%d %d/%d %d/%d\n", i, K, m, L2, j, N);
102                float tmp_xy, tmp_yy, tmp_yp;
103                float score;
104                float g;
105                float s = sign;
106                tmp_xy = xy[m] + s*x[j]               - alpha*s*p[j]*Rxp;
107                tmp_yy = yy[m] + 2*s*y[m][j] + 1      +alpha*alpha*p[j]*p[j]*Rpp - 2*alpha*s*p[j]*yp[m] - 2*alpha*p[j]*p[j];
108                tmp_yp = yp[m] + s*p[j]               *(1-alpha*Rpp);
109                g = (sqrt(tmp_yp*tmp_yp + tmp_yy - tmp_yy*Rpp) - tmp_yp)/tmp_yy;
110                score = 2*g*tmp_xy - g*g*tmp_yy;
111
112                if (score>best_scores[L-1])
113                {
114                   int k, n;
115                   int id = L-1;
116                   while (id > 0 && score > best_scores[id-1])
117                      id--;
118                
119                   for (k=L-1;k>id;k--)
120                   {
121                      nxy[k] = nxy[k-1];
122                      nyy[k] = nyy[k-1];
123                      nyp[k] = nyp[k-1];
124                      //fprintf(stderr, "%d %d \n", N, k);
125                      for (n=0;n<N;n++)
126                         ny[k][n] = ny[k-1][n];
127                      for (n=0;n<N;n++)
128                         iny[k][n] = iny[k-1][n];
129                      gain[k] = gain[k-1];
130                      best_scores[k] = best_scores[k-1];
131                   }
132
133                   nxy[id] = tmp_xy;
134                   nyy[id] = tmp_yy;
135                   nyp[id] = tmp_yp;
136                   gain[id] = g;
137                   for (n=0;n<N;n++)
138                      ny[id][n] = y[m][n];
139                   ny[id][j] += s;
140                   for (n=0;n<N;n++)
141                      ny[id][n] -= alpha*s*p[j]*p[n];
142                
143                   for (n=0;n<N;n++)
144                      iny[id][n] = iy[m][n];
145                   if (s>0)
146                      iny[id][j] += 1;
147                   else
148                      iny[id][j] -= 1;
149                   best_scores[id] = score;
150                }
151             }   
152          }
153          
154       }
155       int k,n;
156       for (k=0;k<L;k++)
157       {
158          xy[k] = nxy[k];
159          yy[k] = nyy[k];
160          yp[k] = nyp[k];
161          for (n=0;n<N;n++)
162             y[k][n] = ny[k][n];
163          for (n=0;n<N;n++)
164             iy[k][n] = iny[k][n];
165       }
166
167    }
168    
169    for (i=0;i<N;i++)
170       x[i] = p[i]+gain[0]*y[0][i];
171    if (0) {
172       float E=1e-15;
173       int ABS = 0;
174       for (i=0;i<N;i++)
175          ABS += abs(iy[0][i]);
176       //if (K != ABS)
177       //   printf ("%d %d\n", K, ABS);
178       for (i=0;i<N;i++)
179          E += x[i]*x[i];
180       //printf ("%f\n", E);
181       E = 1/sqrt(E);
182       for (i=0;i<N;i++)
183          x[i] *= E;
184    }
185    int comb[K];
186    int signs[K];
187    //for (i=0;i<N;i++)
188    //   printf ("%d ", iy[0][i]);
189    pulse2comb(N, K, comb, signs, iy[0]); 
190    ec_enc_uint(enc,icwrs(N, K, comb, signs),ncwrs(N, K));
191 }
192
193 static const float pg[5] = {1.f, .82f, .75f, 0.7f, 0.6f};
194
195 /* Finds the right offset into Y and copy it */
196 void copy_quant(float *x, int N, int K, float *Y, int B, int N0, ec_enc *enc)
197 {
198    int i,j;
199    int best=0;
200    float best_score=0;
201    float s = 1;
202    float E;
203    for (i=0;i<N0*B-N;i+=B)
204    {
205       int j;
206       float xy=0, yy=0;
207       float score;
208       for (j=0;j<N;j++)
209       {
210          xy += x[j]*Y[i+j];
211          yy += Y[i+j]*Y[i+j];
212       }
213       score = xy*xy/(.001+yy);
214       if (score > best_score)
215       {
216          best_score = score;
217          best = i;
218          if (xy>0)
219             s = 1;
220          else
221             s = -1;
222       }
223    }
224    //printf ("e%d e%d ", s, best);
225    if (s==-1)
226       ec_enc_uint(enc,1,1);
227    else
228       ec_enc_uint(enc,0,1);
229    ec_enc_uint(enc,best/B,N0-N/B);
230    //printf ("%d %f\n", best, best_score);
231    if (K==0)
232    {
233       E = 1e-10;
234       for (j=0;j<N;j++)
235       {
236          x[j] = s*Y[best+j];
237          E += x[j]*x[j];
238       }
239       E = 1/sqrt(E);
240       for (j=0;j<N;j++)
241          x[j] *= E;
242    } else {
243       float P[N];
244       float pred_gain;
245       if (K>4)
246          pred_gain = .5;
247       else
248          pred_gain = pg[K];
249       E = 1e-10;
250       for (j=0;j<N;j++)
251       {
252          P[j] = s*Y[best+j];
253          E += P[j]*P[j];
254       }
255       E = .8/sqrt(E);
256       for (j=0;j<N;j++)
257          P[j] *= E;
258       alg_quant(x, N, K, P, enc);
259    }
260 }
261
262 void alg_unquant(float *x, int N, int K, float *p, ec_dec *dec)
263 {
264    int i;
265    unsigned int id;
266    int comb[K];
267    int signs[K];
268    int iy[N];
269    float y[N];
270    float alpha = .9;
271    float Rpp=0, Ryp=0, Ryy=0;
272    float g;
273    
274    id = ec_dec_uint(dec, ncwrs(N, K));
275    cwrsi(N, K, id, comb, signs);
276    comb2pulse(N, K, iy, comb, signs);
277    //for (i=0;i<N;i++)
278    //   printf ("%d ", iy[i]);
279    for (i=0;i<N;i++)
280       Rpp += p[i]*p[i];
281
282    for (i=0;i<N;i++)
283       Ryp += iy[i]*p[i];
284
285    for (i=0;i<N;i++)
286       y[i] = iy[i] - alpha*Ryp*p[i];
287    
288    /* Recompute after the projection (I think it's right) */
289    Ryp = 0;
290    for (i=0;i<N;i++)
291       Ryp += y[i]*p[i];
292    
293    for (i=0;i<N;i++)
294       Ryy += y[i]*y[i];
295
296    g = (sqrt(Ryp*Ryp + Ryy - Ryy*Rpp) - Ryp)/Ryy;
297    
298    for (i=0;i<N;i++)
299       x[i] = p[i] + g*y[i];
300 }
301
302 void copy_unquant(float *x, int N, int K, float *Y, int B, int N0, ec_dec *dec)
303 {
304    int i,j;
305    int s;
306    int best;
307    float E;
308    if (ec_dec_uint(dec, 1) == 0)
309       s = 1;
310    else
311       s = -1;
312    
313    best = B*ec_dec_uint(dec, N0-N/B);
314    printf ("d%d d%d ", s, best);
315
316    if (K==0)
317    {
318       E = 1e-10;
319       for (j=0;j<N;j++)
320       {
321          x[j] = s*Y[best+j];
322          E += x[j]*x[j];
323       }
324       E = 1/sqrt(E);
325       for (j=0;j<N;j++)
326          x[j] *= E;
327    } else {
328       float P[N];
329       float pred_gain;
330       if (K>4)
331          pred_gain = .5;
332       else
333          pred_gain = pg[K];
334       E = 1e-10;
335       for (j=0;j<N;j++)
336       {
337          P[j] = s*Y[best+j];
338          E += P[j]*P[j];
339       }
340       E = .8/sqrt(E);
341       for (j=0;j<N;j++)
342          P[j] *= E;
343       alg_unquant(x, N, K, P, dec);
344    }
345 }