More decoding work
[opus.git] / libcelt / vq.c
1 /* (C) 2007 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #include <math.h>
33 #include <stdlib.h>
34 #include "cwrs.h"
35 #include "vq.h"
36
37
38 /* Improved algebraic pulse-base quantiser. The signal x is replaced by the sum of the pitch 
39    a combination of pulses such that its norm is still equal to 1. The only difference with 
40    the quantiser above is that the search is more complete. */
41 void alg_quant(float *x, int N, int K, float *p, ec_enc *enc)
42 {
43    int L = 5;
44    //float tata[200];
45    float y[L][N];
46    int iy[L][N];
47    //float tata2[200];
48    float ny[L][N];
49    int iny[L][N];
50    int i, j, m;
51    float xy[L], nxy[L];
52    float yy[L], nyy[L];
53    float yp[L], nyp[L];
54    float best_scores[L];
55    float Rpp=0, Rxp=0;
56    float gain[L];
57    int maxL = 1;
58    float alpha = .9;
59    
60    for (j=0;j<N;j++)
61       Rpp += p[j]*p[j];
62    //if (Rpp>.01)
63    //   alpha = (1-sqrt(1-Rpp))/Rpp;
64    for (j=0;j<N;j++)
65       Rxp += x[j]*p[j];
66    for (m=0;m<L;m++)
67       for (i=0;i<N;i++)
68          y[m][i] = 0;
69       
70    for (m=0;m<L;m++)
71       for (i=0;i<N;i++)
72          ny[m][i] = 0;
73
74    for (m=0;m<L;m++)
75       for (i=0;i<N;i++)
76          iy[m][i] = iny[m][i] = 0;
77
78    for (m=0;m<L;m++)
79       xy[m] = yy[m] = yp[m] = gain[m] = 0;
80    
81    for (i=0;i<K;i++)
82    {
83       int L2 = L;
84       if (L>maxL)
85       {
86          L2 = maxL;
87          maxL *= N;
88       }
89       for (m=0;m<L;m++)
90          best_scores[m] = -1e10;
91
92       for (m=0;m<L2;m++)
93       {
94          for (j=0;j<N;j++)
95          {
96             int sign;
97             for (sign=-1;sign<=1;sign+=2)
98             {
99                if (iy[m][j]*sign < 0)
100                   continue;
101                //fprintf (stderr, "%d/%d %d/%d %d/%d\n", i, K, m, L2, j, N);
102                float tmp_xy, tmp_yy, tmp_yp;
103                float score;
104                float g;
105                float s = sign;
106                tmp_xy = xy[m] + s*x[j]               - alpha*s*p[j]*Rxp;
107                tmp_yy = yy[m] + 2*s*y[m][j] + 1      +alpha*alpha*p[j]*p[j]*Rpp - 2*alpha*s*p[j]*yp[m] - 2*alpha*p[j]*p[j];
108                tmp_yp = yp[m] + s*p[j]               *(1-alpha*Rpp);
109                g = (sqrt(tmp_yp*tmp_yp + tmp_yy - tmp_yy*Rpp) - tmp_yp)/tmp_yy;
110                score = 2*g*tmp_xy - g*g*tmp_yy;
111
112                if (score>best_scores[L-1])
113                {
114                   int k, n;
115                   int id = L-1;
116                   while (id > 0 && score > best_scores[id-1])
117                      id--;
118                
119                   for (k=L-1;k>id;k--)
120                   {
121                      nxy[k] = nxy[k-1];
122                      nyy[k] = nyy[k-1];
123                      nyp[k] = nyp[k-1];
124                      //fprintf(stderr, "%d %d \n", N, k);
125                      for (n=0;n<N;n++)
126                         ny[k][n] = ny[k-1][n];
127                      for (n=0;n<N;n++)
128                         iny[k][n] = iny[k-1][n];
129                      gain[k] = gain[k-1];
130                      best_scores[k] = best_scores[k-1];
131                   }
132
133                   nxy[id] = tmp_xy;
134                   nyy[id] = tmp_yy;
135                   nyp[id] = tmp_yp;
136                   gain[id] = g;
137                   for (n=0;n<N;n++)
138                      ny[id][n] = y[m][n];
139                   ny[id][j] += s;
140                   for (n=0;n<N;n++)
141                      ny[id][n] -= alpha*s*p[j]*p[n];
142                
143                   for (n=0;n<N;n++)
144                      iny[id][n] = iy[m][n];
145                   if (s>0)
146                      iny[id][j] += 1;
147                   else
148                      iny[id][j] -= 1;
149                   best_scores[id] = score;
150                }
151             }   
152          }
153          
154       }
155       int k,n;
156       for (k=0;k<L;k++)
157       {
158          xy[k] = nxy[k];
159          yy[k] = nyy[k];
160          yp[k] = nyp[k];
161          for (n=0;n<N;n++)
162             y[k][n] = ny[k][n];
163          for (n=0;n<N;n++)
164             iy[k][n] = iny[k][n];
165       }
166
167    }
168    
169    for (i=0;i<N;i++)
170       x[i] = p[i]+gain[0]*y[0][i];
171    if (0) {
172       float E=1e-15;
173       int ABS = 0;
174       for (i=0;i<N;i++)
175          ABS += abs(iy[0][i]);
176       //if (K != ABS)
177       //   printf ("%d %d\n", K, ABS);
178       for (i=0;i<N;i++)
179          E += x[i]*x[i];
180       //printf ("%f\n", E);
181       E = 1/sqrt(E);
182       for (i=0;i<N;i++)
183          x[i] *= E;
184    }
185    int comb[K];
186    int signs[K];
187    //for (i=0;i<N;i++)
188    //   printf ("%d ", iy[0][i]);
189    pulse2comb(N, K, comb, signs, iy[0]); 
190    ec_enc_uint(enc,icwrs(N, K, comb, signs),ncwrs(N, K));
191 }
192
193 static const float pg[5] = {1.f, .82f, .75f, 0.7f, 0.6f};
194
195 /* Finds the right offset into Y and copy it */
196 void copy_quant(float *x, int N, int K, float *Y, int B, int N0, ec_enc *enc)
197 {
198    int i,j;
199    int best=0;
200    float best_score=0;
201    float s = 1;
202    float E;
203    for (i=0;i<N0*B-N;i+=B)
204    {
205       int j;
206       float xy=0, yy=0;
207       float score;
208       for (j=0;j<N;j++)
209       {
210          xy += x[j]*Y[i+j];
211          yy += Y[i+j]*Y[i+j];
212       }
213       score = xy*xy/(.001+yy);
214       if (score > best_score)
215       {
216          best_score = score;
217          best = i;
218          if (xy>0)
219             s = 1;
220          else
221             s = -1;
222       }
223    }
224    ec_enc_uint(enc,best/B,N0-N/B);
225    //printf ("%d %f\n", best, best_score);
226    if (K==0)
227    {
228       E = 1e-10;
229       for (j=0;j<N;j++)
230       {
231          x[j] = s*Y[best+j];
232          E += x[j]*x[j];
233       }
234       E = 1/sqrt(E);
235       for (j=0;j<N;j++)
236          x[j] *= E;
237    } else {
238       float P[N];
239       float pred_gain;
240       if (K>4)
241          pred_gain = .5;
242       else
243          pred_gain = pg[K];
244       E = 1e-10;
245       for (j=0;j<N;j++)
246       {
247          P[j] = s*Y[best+j];
248          E += P[j]*P[j];
249       }
250       E = .8/sqrt(E);
251       for (j=0;j<N;j++)
252          P[j] *= E;
253       alg_quant(x, N, K, P, enc);
254    }
255 }
256
257 void alg_unquant(float *x, int N, int K, float *p, ec_dec *dec)
258 {
259    int i;
260    unsigned int id;
261    int comb[K];
262    int signs[K];
263    int iy[N];
264    float y[N];
265    float alpha = .9;
266    float Rpp=0, Ryp=0, Ryy=0;
267    float g;
268    
269    id = ec_dec_uint(dec, ncwrs(N, K));
270    cwrsi(N, K, id, comb, signs);
271    comb2pulse(N, K, iy, comb, signs);
272    //for (i=0;i<N;i++)
273    //   printf ("%d ", iy[i]);
274    for (i=0;i<N;i++)
275       Rpp += p[i]*p[i];
276
277    for (i=0;i<N;i++)
278       Ryp += iy[i]*p[i];
279
280    for (i=0;i<N;i++)
281       y[i] = iy[i] - alpha*Ryp*p[i];
282    
283    /* Recompute after the projection (I think it's right) */
284    Ryp = 0;
285    for (i=0;i<N;i++)
286       Ryp += y[i]*p[i];
287    
288    for (i=0;i<N;i++)
289       Ryy += y[i]*y[i];
290
291    g = (sqrt(Ryp*Ryp + Ryy - Ryy*Rpp) - Ryp)/Ryy;
292    
293    for (i=0;i<N;i++)
294       x[i] = p[i] + g*y[i];
295 }
296