moved pulse [en|de]coding to cwrs.c
[opus.git] / libcelt / vq.c
1 /* (C) 2007 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #include <math.h>
33 #include <stdlib.h>
34 #include "cwrs.h"
35 #include "vq.h"
36
37
38 /* Improved algebraic pulse-base quantiser. The signal x is replaced by the sum of the pitch 
39    a combination of pulses such that its norm is still equal to 1. The only difference with 
40    the quantiser above is that the search is more complete. */
41 void alg_quant(float *x, float *W, int N, int K, float *p, float alpha, ec_enc *enc)
42 {
43    int L = 3;
44    //float tata[200];
45    float _y[L][N];
46    int _iy[L][N];
47    //float tata2[200];
48    float _ny[L][N];
49    int _iny[L][N];
50    float *(ny[L]), *(y[L]);
51    int *(iny[L]), *(iy[L]);
52    int i, j, m;
53    int pulsesLeft;
54    float xy[L], nxy[L];
55    float yy[L], nyy[L];
56    float yp[L], nyp[L];
57    float best_scores[L];
58    float Rpp=0, Rxp=0;
59    float gain[L];
60    int maxL = 1;
61    
62    for (m=0;m<L;m++)
63    {
64       ny[m] = _ny[m];
65       iny[m] = _iny[m];
66       y[m] = _y[m];
67       iy[m] = _iy[m];
68    }
69    
70    for (j=0;j<N;j++)
71       Rpp += p[j]*p[j];
72    //if (Rpp>.01)
73    //   alpha = (1-sqrt(1-Rpp))/Rpp;
74    for (j=0;j<N;j++)
75       Rxp += x[j]*p[j];
76    for (m=0;m<L;m++)
77       for (i=0;i<N;i++)
78          y[m][i] = 0;
79       
80    for (m=0;m<L;m++)
81       for (i=0;i<N;i++)
82          ny[m][i] = 0;
83
84    for (m=0;m<L;m++)
85       for (i=0;i<N;i++)
86          iy[m][i] = iny[m][i] = 0;
87
88    for (m=0;m<L;m++)
89       xy[m] = yy[m] = yp[m] = gain[m] = 0;
90    
91    pulsesLeft = K;
92    while (pulsesLeft > 0)
93    {
94       int pulsesAtOnce=1;
95       int Lupdate = L;
96       int L2 = L;
97       pulsesAtOnce = pulsesLeft/N;
98       if (pulsesAtOnce<1)
99          pulsesAtOnce = 1;
100       if (pulsesLeft-pulsesAtOnce > 3 || N > 30)
101          Lupdate = 1;
102       //printf ("%d %d %d/%d %d\n", Lupdate, pulsesAtOnce, pulsesLeft, K, N);
103       L2 = Lupdate;
104       if (L2>maxL)
105       {
106          L2 = maxL;
107          maxL *= N;
108       }
109
110       for (m=0;m<L;m++)
111          best_scores[m] = -1e10;
112
113       for (m=0;m<L2;m++)
114       {
115          for (j=0;j<N;j++)
116          {
117             int sign;
118             //if (x[j]>0) sign=1; else sign=-1;
119             for (sign=-1;sign<=1;sign+=2)
120             {
121                /* All pulses at one location must have the same sign. Also,
122                   only consider sign in the same direction as x[j], except for the
123                   last pulses */
124                if (iy[m][j]*sign < 0 || (x[j]*sign<0 && pulsesLeft>((K+1)>>1)))
125                   continue;
126                //fprintf (stderr, "%d/%d %d/%d %d/%d\n", i, K, m, L2, j, N);
127                float tmp_xy, tmp_yy, tmp_yp;
128                float score;
129                float g;
130                float s = sign*pulsesAtOnce;
131                
132                /* Updating the sums the the new pulse(s) */
133                tmp_xy = xy[m] + s*x[j]               - alpha*s*p[j]*Rxp;
134                tmp_yy = yy[m] + 2*s*y[m][j] + s*s      +s*s*alpha*alpha*p[j]*p[j]*Rpp - 2*alpha*s*p[j]*yp[m] - 2*s*s*alpha*p[j]*p[j];
135                tmp_yp = yp[m] + s*p[j]               *(1-alpha*Rpp);
136                g = (sqrt(tmp_yp*tmp_yp + tmp_yy - tmp_yy*Rpp) - tmp_yp)/tmp_yy;
137                score = 2*g*tmp_xy - g*g*tmp_yy;
138
139                if (score>best_scores[Lupdate-1])
140                {
141                   int k, n;
142                   int id = Lupdate-1;
143                   float *tmp_ny;
144                   int *tmp_iny;
145                   
146                   tmp_ny = ny[Lupdate-1];
147                   tmp_iny = iny[Lupdate-1];
148                   while (id > 0 && score > best_scores[id-1])
149                      id--;
150                
151                   for (k=Lupdate-1;k>id;k--)
152                   {
153                      nxy[k] = nxy[k-1];
154                      nyy[k] = nyy[k-1];
155                      nyp[k] = nyp[k-1];
156                      //fprintf(stderr, "%d %d \n", N, k);
157                      ny[k] = ny[k-1];
158                      iny[k] = iny[k-1];
159                      gain[k] = gain[k-1];
160                      best_scores[k] = best_scores[k-1];
161                   }
162
163                   ny[id] = tmp_ny;
164                   iny[id] = tmp_iny;
165
166                   nxy[id] = tmp_xy;
167                   nyy[id] = tmp_yy;
168                   nyp[id] = tmp_yp;
169                   gain[id] = g;
170                   for (n=0;n<N;n++)
171                      ny[id][n] = y[m][n] - alpha*s*p[j]*p[n];
172                   ny[id][j] += s;
173
174                   for (n=0;n<N;n++)
175                      iny[id][n] = iy[m][n];
176                   if (s>0)
177                      iny[id][j] += pulsesAtOnce;
178                   else
179                      iny[id][j] -= pulsesAtOnce;
180                   best_scores[id] = score;
181                }
182             }
183          }
184
185       }
186       int k;
187       for (k=0;k<Lupdate;k++)
188       {
189          float *tmp_ny;
190          int *tmp_iny;
191
192          xy[k] = nxy[k];
193          yy[k] = nyy[k];
194          yp[k] = nyp[k];
195          
196          tmp_ny = ny[k];
197          ny[k] = y[k];
198          y[k] = tmp_ny;
199          tmp_iny = iny[k];
200          iny[k] = iy[k];
201          iy[k] = tmp_iny;
202       }
203       pulsesLeft -= pulsesAtOnce;
204    }
205    
206    if (0) {
207       float err=0;
208       for (i=0;i<N;i++)
209          err += (x[i]-gain[0]*y[0][i])*(x[i]-gain[0]*y[0][i]);
210       //if (N<=10)
211       //printf ("%f %d %d\n", err, K, N);
212    }
213    for (i=0;i<N;i++)
214       x[i] = p[i]+gain[0]*y[0][i];
215    if (0) {
216       float E=1e-15;
217       int ABS = 0;
218       for (i=0;i<N;i++)
219          ABS += abs(iy[0][i]);
220       //if (K != ABS)
221       //   printf ("%d %d\n", K, ABS);
222       for (i=0;i<N;i++)
223          E += x[i]*x[i];
224       //printf ("%f\n", E);
225       E = 1/sqrt(E);
226       for (i=0;i<N;i++)
227          x[i] *= E;
228    }
229    
230    encode_pulses(iy[0], N, K, enc);
231    
232    //printf ("%llu ", icwrs64(N, K, comb, signs));
233    /* Recompute the gain in one pass to reduce the encoder-decoder mismatch
234       due to the recursive computation used in quantisation.
235       Not quite sure whether we need that or not */
236    if (1) {
237       float Ryp=0;
238       float Rpp=0;
239       float Ryy=0;
240       float g=0;
241       for (i=0;i<N;i++)
242          Rpp += p[i]*p[i];
243       
244       for (i=0;i<N;i++)
245          Ryp += iy[0][i]*p[i];
246       
247       for (i=0;i<N;i++)
248          y[0][i] = iy[0][i] - alpha*Ryp*p[i];
249       
250       Ryp = 0;
251       for (i=0;i<N;i++)
252          Ryp += y[0][i]*p[i];
253       
254       for (i=0;i<N;i++)
255          Ryy += y[0][i]*y[0][i];
256       
257       g = (sqrt(Ryp*Ryp + Ryy - Ryy*Rpp) - Ryp)/Ryy;
258         
259       for (i=0;i<N;i++)
260          x[i] = p[i] + g*y[0][i];
261       
262    }
263
264 }
265
266 void alg_unquant(float *x, int N, int K, float *p, float alpha, ec_dec *dec)
267 {
268    int i;
269    int iy[N];
270    float y[N];
271    float Rpp=0, Ryp=0, Ryy=0;
272    float g;
273
274    decode_pulses(iy, N, K, dec);
275
276    //for (i=0;i<N;i++)
277    //   printf ("%d ", iy[i]);
278    for (i=0;i<N;i++)
279       Rpp += p[i]*p[i];
280
281    for (i=0;i<N;i++)
282       Ryp += iy[i]*p[i];
283
284    for (i=0;i<N;i++)
285       y[i] = iy[i] - alpha*Ryp*p[i];
286
287    /* Recompute after the projection (I think it's right) */
288    Ryp = 0;
289    for (i=0;i<N;i++)
290       Ryp += y[i]*p[i];
291
292    for (i=0;i<N;i++)
293       Ryy += y[i]*y[i];
294
295    g = (sqrt(Ryp*Ryp + Ryy - Ryy*Rpp) - Ryp)/Ryy;
296
297    for (i=0;i<N;i++)
298       x[i] = p[i] + g*y[i];
299 }
300
301
302 static const float pg[11] = {1.f, .75f, .65f, 0.6f, 0.6f, .6f, .55f, .55f, .5f, .5f, .5f};
303
304 void intra_prediction(float *x, float *W, int N, int K, float *Y, float *P, int B, int N0, ec_enc *enc)
305 {
306    int i,j;
307    int best=0;
308    float best_score=0;
309    float s = 1;
310    int sign;
311    float E;
312    int max_pos = N0-N/B;
313    if (max_pos > 32)
314       max_pos = 32;
315
316    for (i=0;i<max_pos*B;i+=B)
317    {
318       int j;
319       float xy=0, yy=0;
320       float score;
321       for (j=0;j<N;j++)
322       {
323          xy += x[j]*Y[i+j];
324          yy += Y[i+j]*Y[i+j];
325       }
326       score = xy*xy/(.001+yy);
327       if (score > best_score)
328       {
329          best_score = score;
330          best = i;
331          if (xy>0)
332             s = 1;
333          else
334             s = -1;
335       }
336    }
337    if (s<0)
338       sign = 1;
339    else
340       sign = 0;
341    //printf ("%d %d ", sign, best);
342    ec_enc_uint(enc,sign,2);
343    ec_enc_uint(enc,best/B,max_pos);
344    //printf ("%d %f\n", best, best_score);
345    
346    float pred_gain;
347    if (K>10)
348       pred_gain = pg[10];
349    else
350       pred_gain = pg[K];
351    E = 1e-10;
352    for (j=0;j<N;j++)
353    {
354       P[j] = s*Y[best+j];
355       E += P[j]*P[j];
356    }
357    E = pred_gain/sqrt(E);
358    for (j=0;j<N;j++)
359       P[j] *= E;
360    if (K>0)
361    {
362       for (j=0;j<N;j++)
363          x[j] -= P[j];
364    } else {
365       for (j=0;j<N;j++)
366          x[j] = P[j];
367    }
368    //printf ("quant ");
369    //for (j=0;j<N;j++) printf ("%f ", P[j]);
370
371 }
372
373 void intra_unquant(float *x, int N, int K, float *Y, float *P, int B, int N0, ec_dec *dec)
374 {
375    int j;
376    int sign;
377    float s;
378    int best;
379    float E;
380    int max_pos = N0-N/B;
381    if (max_pos > 32)
382       max_pos = 32;
383    
384    sign = ec_dec_uint(dec, 2);
385    if (sign == 0)
386       s = 1;
387    else
388       s = -1;
389    
390    best = B*ec_dec_uint(dec, max_pos);
391    //printf ("%d %d ", sign, best);
392
393    float pred_gain;
394    if (K>10)
395       pred_gain = pg[10];
396    else
397       pred_gain = pg[K];
398    E = 1e-10;
399    for (j=0;j<N;j++)
400    {
401       P[j] = s*Y[best+j];
402       E += P[j]*P[j];
403    }
404    E = pred_gain/sqrt(E);
405    for (j=0;j<N;j++)
406       P[j] *= E;
407    if (K==0)
408    {
409       for (j=0;j<N;j++)
410          x[j] = P[j];
411    }
412 }
413
414 void intra_fold(float *x, int N, float *Y, float *P, int B, int N0, int Nmax)
415 {
416    int i, j;
417    float E;
418    
419    E = 1e-10;
420    if (N0 >= Nmax/2)
421    {
422       for (i=0;i<B;i++)
423       {
424          for (j=0;j<N/B;j++)
425          {
426             P[j*B+i] = Y[(Nmax-N0-j-1)*B+i];
427             E += P[j*B+i]*P[j*B+i];
428          }
429       }
430    } else {
431       for (j=0;j<N;j++)
432       {
433          P[j] = Y[j];
434          E += P[j]*P[j];
435       }
436    }
437    E = 1.f/sqrt(E);
438    for (j=0;j<N;j++)
439       P[j] *= E;
440    for (j=0;j<N;j++)
441       x[j] = P[j];
442 }
443