Limiting intra-frame prediction codebook to 32 entries (plus sign)
[opus.git] / libcelt / vq.c
1 /* (C) 2007 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #include <math.h>
33 #include <stdlib.h>
34 #include "cwrs.h"
35 #include "vq.h"
36
37
38 /* Improved algebraic pulse-base quantiser. The signal x is replaced by the sum of the pitch 
39    a combination of pulses such that its norm is still equal to 1. The only difference with 
40    the quantiser above is that the search is more complete. */
41 void alg_quant(float *x, float *W, int N, int K, float *p, float alpha, ec_enc *enc)
42 {
43    int L = 3;
44    //float tata[200];
45    float y[L][N];
46    int iy[L][N];
47    //float tata2[200];
48    float ny[L][N];
49    int iny[L][N];
50    int i, j, m;
51    float xy[L], nxy[L];
52    float yy[L], nyy[L];
53    float yp[L], nyp[L];
54    float best_scores[L];
55    float Rpp=0, Rxp=0;
56    float gain[L];
57    int maxL = 1;
58    
59    for (j=0;j<N;j++)
60       Rpp += p[j]*p[j];
61    //if (Rpp>.01)
62    //   alpha = (1-sqrt(1-Rpp))/Rpp;
63    for (j=0;j<N;j++)
64       Rxp += x[j]*p[j];
65    for (m=0;m<L;m++)
66       for (i=0;i<N;i++)
67          y[m][i] = 0;
68       
69    for (m=0;m<L;m++)
70       for (i=0;i<N;i++)
71          ny[m][i] = 0;
72
73    for (m=0;m<L;m++)
74       for (i=0;i<N;i++)
75          iy[m][i] = iny[m][i] = 0;
76
77    for (m=0;m<L;m++)
78       xy[m] = yy[m] = yp[m] = gain[m] = 0;
79    
80    for (i=0;i<K;i++)
81    {
82       int L2 = L;
83       if (L>maxL)
84       {
85          L2 = maxL;
86          maxL *= N;
87       }
88       for (m=0;m<L;m++)
89          best_scores[m] = -1e10;
90
91       for (m=0;m<L2;m++)
92       {
93          for (j=0;j<N;j++)
94          {
95             int sign;
96             for (sign=-1;sign<=1;sign+=2)
97             {
98                if (iy[m][j]*sign < 0)
99                   continue;
100                //fprintf (stderr, "%d/%d %d/%d %d/%d\n", i, K, m, L2, j, N);
101                float tmp_xy, tmp_yy, tmp_yp;
102                float score;
103                float g;
104                float s = sign;
105                tmp_xy = xy[m] + s*x[j]               - alpha*s*p[j]*Rxp;
106                tmp_yy = yy[m] + 2*s*y[m][j] + 1      +alpha*alpha*p[j]*p[j]*Rpp - 2*alpha*s*p[j]*yp[m] - 2*alpha*p[j]*p[j];
107                tmp_yp = yp[m] + s*p[j]               *(1-alpha*Rpp);
108                g = (sqrt(tmp_yp*tmp_yp + tmp_yy - tmp_yy*Rpp) - tmp_yp)/tmp_yy;
109                score = 2*g*tmp_xy - g*g*tmp_yy;
110
111                if (score>best_scores[L-1])
112                {
113                   int k, n;
114                   int id = L-1;
115                   while (id > 0 && score > best_scores[id-1])
116                      id--;
117                
118                   for (k=L-1;k>id;k--)
119                   {
120                      nxy[k] = nxy[k-1];
121                      nyy[k] = nyy[k-1];
122                      nyp[k] = nyp[k-1];
123                      //fprintf(stderr, "%d %d \n", N, k);
124                      for (n=0;n<N;n++)
125                         ny[k][n] = ny[k-1][n];
126                      for (n=0;n<N;n++)
127                         iny[k][n] = iny[k-1][n];
128                      gain[k] = gain[k-1];
129                      best_scores[k] = best_scores[k-1];
130                   }
131
132                   nxy[id] = tmp_xy;
133                   nyy[id] = tmp_yy;
134                   nyp[id] = tmp_yp;
135                   gain[id] = g;
136                   for (n=0;n<N;n++)
137                      ny[id][n] = y[m][n];
138                   ny[id][j] += s;
139                   for (n=0;n<N;n++)
140                      ny[id][n] -= alpha*s*p[j]*p[n];
141                
142                   for (n=0;n<N;n++)
143                      iny[id][n] = iy[m][n];
144                   if (s>0)
145                      iny[id][j] += 1;
146                   else
147                      iny[id][j] -= 1;
148                   best_scores[id] = score;
149                }
150             }   
151          }
152          
153       }
154       int k,n;
155       for (k=0;k<L;k++)
156       {
157          xy[k] = nxy[k];
158          yy[k] = nyy[k];
159          yp[k] = nyp[k];
160          for (n=0;n<N;n++)
161             y[k][n] = ny[k][n];
162          for (n=0;n<N;n++)
163             iy[k][n] = iny[k][n];
164       }
165
166    }
167    
168    for (i=0;i<N;i++)
169       x[i] = p[i]+gain[0]*y[0][i];
170    if (0) {
171       float E=1e-15;
172       int ABS = 0;
173       for (i=0;i<N;i++)
174          ABS += abs(iy[0][i]);
175       //if (K != ABS)
176       //   printf ("%d %d\n", K, ABS);
177       for (i=0;i<N;i++)
178          E += x[i]*x[i];
179       //printf ("%f\n", E);
180       E = 1/sqrt(E);
181       for (i=0;i<N;i++)
182          x[i] *= E;
183    }
184    int comb[K];
185    int signs[K];
186    //for (i=0;i<N;i++)
187    //   printf ("%d ", iy[0][i]);
188    pulse2comb(N, K, comb, signs, iy[0]); 
189    ec_enc_uint64(enc,icwrs64(N, K, comb, signs),ncwrs64(N, K));
190    //printf ("%llu ", icwrs64(N, K, comb, signs));
191    /* Recompute the gain in one pass to reduce the encoder-decoder mismatch
192       due to the recursive computation used in quantisation */
193    if (1) {
194       float Ryp=0;
195       float Rpp=0;
196       float Ryy=0;
197       float g=0;
198       for (i=0;i<N;i++)
199          Rpp += p[i]*p[i];
200       
201       for (i=0;i<N;i++)
202          Ryp += iy[0][i]*p[i];
203       
204       for (i=0;i<N;i++)
205          y[0][i] = iy[0][i] - alpha*Ryp*p[i];
206       
207       Ryp = 0;
208       for (i=0;i<N;i++)
209          Ryp += y[0][i]*p[i];
210       
211       for (i=0;i<N;i++)
212          Ryy += y[0][i]*y[0][i];
213       
214       g = (sqrt(Ryp*Ryp + Ryy - Ryy*Rpp) - Ryp)/Ryy;
215         
216       for (i=0;i<N;i++)
217          x[i] = p[i] + g*y[0][i];
218       
219    }
220
221 }
222
223 void alg_unquant(float *x, int N, int K, float *p, float alpha, ec_dec *dec)
224 {
225    int i;
226    celt_uint64_t id;
227    int comb[K];
228    int signs[K];
229    int iy[N];
230    float y[N];
231    float Rpp=0, Ryp=0, Ryy=0;
232    float g;
233
234    id = ec_dec_uint64(dec, ncwrs64(N, K));
235    //printf ("%llu ", id);
236    cwrsi64(N, K, id, comb, signs);
237    comb2pulse(N, K, iy, comb, signs);
238    //for (i=0;i<N;i++)
239    //   printf ("%d ", iy[i]);
240    for (i=0;i<N;i++)
241       Rpp += p[i]*p[i];
242
243    for (i=0;i<N;i++)
244       Ryp += iy[i]*p[i];
245
246    for (i=0;i<N;i++)
247       y[i] = iy[i] - alpha*Ryp*p[i];
248
249    /* Recompute after the projection (I think it's right) */
250    Ryp = 0;
251    for (i=0;i<N;i++)
252       Ryp += y[i]*p[i];
253
254    for (i=0;i<N;i++)
255       Ryy += y[i]*y[i];
256
257    g = (sqrt(Ryp*Ryp + Ryy - Ryy*Rpp) - Ryp)/Ryy;
258
259    for (i=0;i<N;i++)
260       x[i] = p[i] + g*y[i];
261 }
262
263
264 static const float pg[11] = {1.f, .75f, .65f, 0.6f, 0.6f, .6f, .55f, .55f, .5f, .5f, .5f};
265
266 void intra_prediction(float *x, float *W, int N, int K, float *Y, float *P, int B, int N0, ec_enc *enc)
267 {
268    int i,j;
269    int best=0;
270    float best_score=0;
271    float s = 1;
272    int sign;
273    float E;
274    int max_pos = N0-N/B;
275    if (max_pos > 32)
276       max_pos = 32;
277
278    for (i=0;i<max_pos*B;i+=B)
279    {
280       int j;
281       float xy=0, yy=0;
282       float score;
283       for (j=0;j<N;j++)
284       {
285          xy += x[j]*Y[i+j];
286          yy += Y[i+j]*Y[i+j];
287       }
288       score = xy*xy/(.001+yy);
289       if (score > best_score)
290       {
291          best_score = score;
292          best = i;
293          if (xy>0)
294             s = 1;
295          else
296             s = -1;
297       }
298    }
299    if (s<0)
300       sign = 1;
301    else
302       sign = 0;
303    //printf ("%d %d ", sign, best);
304    ec_enc_uint(enc,sign,2);
305    ec_enc_uint(enc,best/B,max_pos);
306    //printf ("%d %f\n", best, best_score);
307    
308    float pred_gain;
309    if (K>10)
310       pred_gain = pg[10];
311    else
312       pred_gain = pg[K];
313    E = 1e-10;
314    for (j=0;j<N;j++)
315    {
316       P[j] = s*Y[best+j];
317       E += P[j]*P[j];
318    }
319    E = pred_gain/sqrt(E);
320    for (j=0;j<N;j++)
321       P[j] *= E;
322    if (K>0)
323    {
324       for (j=0;j<N;j++)
325          x[j] -= P[j];
326    } else {
327       for (j=0;j<N;j++)
328          x[j] = P[j];
329    }
330    //printf ("quant ");
331    //for (j=0;j<N;j++) printf ("%f ", P[j]);
332
333 }
334
335 void intra_unquant(float *x, int N, int K, float *Y, float *P, int B, int N0, ec_dec *dec)
336 {
337    int j;
338    int sign;
339    float s;
340    int best;
341    float E;
342    int max_pos = N0-N/B;
343    if (max_pos > 32)
344       max_pos = 32;
345    
346    sign = ec_dec_uint(dec, 2);
347    if (sign == 0)
348       s = 1;
349    else
350       s = -1;
351    
352    best = B*ec_dec_uint(dec, max_pos);
353    //printf ("%d %d ", sign, best);
354
355    float pred_gain;
356    if (K>10)
357       pred_gain = pg[10];
358    else
359       pred_gain = pg[K];
360    E = 1e-10;
361    for (j=0;j<N;j++)
362    {
363       P[j] = s*Y[best+j];
364       E += P[j]*P[j];
365    }
366    E = pred_gain/sqrt(E);
367    for (j=0;j<N;j++)
368       P[j] *= E;
369    if (K==0)
370    {
371       for (j=0;j<N;j++)
372          x[j] = P[j];
373    }
374 }