Cheating decoder now produces the same result as the encoder
[opus.git] / libcelt / vq.c
1 /* (C) 2007 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #include <math.h>
33 #include <stdlib.h>
34 #include "cwrs.h"
35 #include "vq.h"
36
37
38 /* Improved algebraic pulse-base quantiser. The signal x is replaced by the sum of the pitch 
39    a combination of pulses such that its norm is still equal to 1. The only difference with 
40    the quantiser above is that the search is more complete. */
41 void alg_quant(float *x, int N, int K, float *p, ec_enc *enc)
42 {
43    int L = 5;
44    //float tata[200];
45    float y[L][N];
46    int iy[L][N];
47    //float tata2[200];
48    float ny[L][N];
49    int iny[L][N];
50    int i, j, m;
51    float xy[L], nxy[L];
52    float yy[L], nyy[L];
53    float yp[L], nyp[L];
54    float best_scores[L];
55    float Rpp=0, Rxp=0;
56    float gain[L];
57    int maxL = 1;
58    float alpha = .9;
59    
60    for (j=0;j<N;j++)
61       Rpp += p[j]*p[j];
62    //if (Rpp>.01)
63    //   alpha = (1-sqrt(1-Rpp))/Rpp;
64    for (j=0;j<N;j++)
65       Rxp += x[j]*p[j];
66    for (m=0;m<L;m++)
67       for (i=0;i<N;i++)
68          y[m][i] = 0;
69       
70    for (m=0;m<L;m++)
71       for (i=0;i<N;i++)
72          ny[m][i] = 0;
73
74    for (m=0;m<L;m++)
75       for (i=0;i<N;i++)
76          iy[m][i] = iny[m][i] = 0;
77
78    for (m=0;m<L;m++)
79       xy[m] = yy[m] = yp[m] = gain[m] = 0;
80    
81    for (i=0;i<K;i++)
82    {
83       int L2 = L;
84       if (L>maxL)
85       {
86          L2 = maxL;
87          maxL *= N;
88       }
89       for (m=0;m<L;m++)
90          best_scores[m] = -1e10;
91
92       for (m=0;m<L2;m++)
93       {
94          for (j=0;j<N;j++)
95          {
96             int sign;
97             for (sign=-1;sign<=1;sign+=2)
98             {
99                if (iy[m][j]*sign < 0)
100                   continue;
101                //fprintf (stderr, "%d/%d %d/%d %d/%d\n", i, K, m, L2, j, N);
102                float tmp_xy, tmp_yy, tmp_yp;
103                float score;
104                float g;
105                float s = sign;
106                tmp_xy = xy[m] + s*x[j]               - alpha*s*p[j]*Rxp;
107                tmp_yy = yy[m] + 2*s*y[m][j] + 1      +alpha*alpha*p[j]*p[j]*Rpp - 2*alpha*s*p[j]*yp[m] - 2*alpha*p[j]*p[j];
108                tmp_yp = yp[m] + s*p[j]               *(1-alpha*Rpp);
109                g = (sqrt(tmp_yp*tmp_yp + tmp_yy - tmp_yy*Rpp) - tmp_yp)/tmp_yy;
110                score = 2*g*tmp_xy - g*g*tmp_yy;
111
112                if (score>best_scores[L-1])
113                {
114                   int k, n;
115                   int id = L-1;
116                   while (id > 0 && score > best_scores[id-1])
117                      id--;
118                
119                   for (k=L-1;k>id;k--)
120                   {
121                      nxy[k] = nxy[k-1];
122                      nyy[k] = nyy[k-1];
123                      nyp[k] = nyp[k-1];
124                      //fprintf(stderr, "%d %d \n", N, k);
125                      for (n=0;n<N;n++)
126                         ny[k][n] = ny[k-1][n];
127                      for (n=0;n<N;n++)
128                         iny[k][n] = iny[k-1][n];
129                      gain[k] = gain[k-1];
130                      best_scores[k] = best_scores[k-1];
131                   }
132
133                   nxy[id] = tmp_xy;
134                   nyy[id] = tmp_yy;
135                   nyp[id] = tmp_yp;
136                   gain[id] = g;
137                   for (n=0;n<N;n++)
138                      ny[id][n] = y[m][n];
139                   ny[id][j] += s;
140                   for (n=0;n<N;n++)
141                      ny[id][n] -= alpha*s*p[j]*p[n];
142                
143                   for (n=0;n<N;n++)
144                      iny[id][n] = iy[m][n];
145                   if (s>0)
146                      iny[id][j] += 1;
147                   else
148                      iny[id][j] -= 1;
149                   best_scores[id] = score;
150                }
151             }   
152          }
153          
154       }
155       int k,n;
156       for (k=0;k<L;k++)
157       {
158          xy[k] = nxy[k];
159          yy[k] = nyy[k];
160          yp[k] = nyp[k];
161          for (n=0;n<N;n++)
162             y[k][n] = ny[k][n];
163          for (n=0;n<N;n++)
164             iy[k][n] = iny[k][n];
165       }
166
167    }
168    
169    for (i=0;i<N;i++)
170       x[i] = p[i]+gain[0]*y[0][i];
171    if (0) {
172       float E=1e-15;
173       int ABS = 0;
174       for (i=0;i<N;i++)
175          ABS += abs(iy[0][i]);
176       //if (K != ABS)
177       //   printf ("%d %d\n", K, ABS);
178       for (i=0;i<N;i++)
179          E += x[i]*x[i];
180       //printf ("%f\n", E);
181       E = 1/sqrt(E);
182       for (i=0;i<N;i++)
183          x[i] *= E;
184    }
185    int comb[K];
186    int signs[K];
187    //for (i=0;i<N;i++)
188    //   printf ("%d ", iy[0][i]);
189    pulse2comb(N, K, comb, signs, iy[0]); 
190    ec_enc_uint(enc,icwrs(N, K, comb, signs),ncwrs(N, K));
191 }
192
193 static const float pg[5] = {1.f, .82f, .75f, 0.7f, 0.6f};
194
195 /* Finds the right offset into Y and copy it */
196 void copy_quant(float *x, int N, int K, float *Y, int B, int N0, ec_enc *enc)
197 {
198    int i,j;
199    int best=0;
200    float best_score=0;
201    float s = 1;
202    int sign;
203    float E;
204    for (i=0;i<N0*B-N;i+=B)
205    {
206       int j;
207       float xy=0, yy=0;
208       float score;
209       for (j=0;j<N;j++)
210       {
211          xy += x[j]*Y[i+j];
212          yy += Y[i+j]*Y[i+j];
213       }
214       score = xy*xy/(.001+yy);
215       if (score > best_score)
216       {
217          best_score = score;
218          best = i;
219          if (xy>0)
220             s = 1;
221          else
222             s = -1;
223       }
224    }
225    if (s<0)
226       sign = 1;
227    else
228       sign = 0;
229    //printf ("%d %d ", sign, best);
230    ec_enc_uint(enc,sign,2);
231    ec_enc_uint(enc,best/B,N0-N/B);
232    //printf ("%d %f\n", best, best_score);
233    if (K==0)
234    {
235       E = 1e-10;
236       for (j=0;j<N;j++)
237       {
238          x[j] = s*Y[best+j];
239          E += x[j]*x[j];
240       }
241       E = 1/sqrt(E);
242       for (j=0;j<N;j++)
243          x[j] *= E;
244    } else {
245       float P[N];
246       float pred_gain;
247       if (K>4)
248          pred_gain = .5;
249       else
250          pred_gain = pg[K];
251       E = 1e-10;
252       for (j=0;j<N;j++)
253       {
254          P[j] = s*Y[best+j];
255          E += P[j]*P[j];
256       }
257       E = .8/sqrt(E);
258       for (j=0;j<N;j++)
259          P[j] *= E;
260       alg_quant(x, N, K, P, enc);
261    }
262 }
263
264 void alg_unquant(float *x, int N, int K, float *p, ec_dec *dec)
265 {
266    int i;
267    unsigned int id;
268    int comb[K];
269    int signs[K];
270    int iy[N];
271    float y[N];
272    float alpha = .9;
273    float Rpp=0, Ryp=0, Ryy=0;
274    float g;
275    
276    id = ec_dec_uint(dec, ncwrs(N, K));
277    cwrsi(N, K, id, comb, signs);
278    comb2pulse(N, K, iy, comb, signs);
279    //for (i=0;i<N;i++)
280    //   printf ("%d ", iy[i]);
281    for (i=0;i<N;i++)
282       Rpp += p[i]*p[i];
283
284    for (i=0;i<N;i++)
285       Ryp += iy[i]*p[i];
286
287    for (i=0;i<N;i++)
288       y[i] = iy[i] - alpha*Ryp*p[i];
289    
290    /* Recompute after the projection (I think it's right) */
291    Ryp = 0;
292    for (i=0;i<N;i++)
293       Ryp += y[i]*p[i];
294    
295    for (i=0;i<N;i++)
296       Ryy += y[i]*y[i];
297
298    g = (sqrt(Ryp*Ryp + Ryy - Ryy*Rpp) - Ryp)/Ryy;
299    
300    for (i=0;i<N;i++)
301       x[i] = p[i] + g*y[i];
302 }
303
304 void copy_unquant(float *x, int N, int K, float *Y, int B, int N0, ec_dec *dec)
305 {
306    int i,j;
307    int sign;
308    float s;
309    int best;
310    float E;
311    sign = ec_dec_uint(dec, 2);
312    if (sign == 0)
313       s = 1;
314    else
315       s = -1;
316    
317    best = B*ec_dec_uint(dec, N0-N/B);
318    //printf ("%d %d ", sign, best);
319
320    if (K==0)
321    {
322       E = 1e-10;
323       for (j=0;j<N;j++)
324       {
325          x[j] = s*Y[best+j];
326          E += x[j]*x[j];
327       }
328       E = 1/sqrt(E);
329       for (j=0;j<N;j++)
330          x[j] *= E;
331    } else {
332       float P[N];
333       float pred_gain;
334       if (K>4)
335          pred_gain = .5;
336       else
337          pred_gain = pg[K];
338       E = 1e-10;
339       for (j=0;j<N;j++)
340       {
341          P[j] = s*Y[best+j];
342          E += P[j]*P[j];
343       }
344       E = .8/sqrt(E);
345       for (j=0;j<N;j++)
346          P[j] *= E;
347       alg_unquant(x, N, K, P, dec);
348    }
349 }