Got the intra-band prediction/copy to work correctly with
[opus.git] / libcelt / vq.c
1 /* (C) 2007 Jean-Marc Valin, CSIRO
2 */
3 /*
4    Redistribution and use in source and binary forms, with or without
5    modification, are permitted provided that the following conditions
6    are met:
7    
8    - Redistributions of source code must retain the above copyright
9    notice, this list of conditions and the following disclaimer.
10    
11    - Redistributions in binary form must reproduce the above copyright
12    notice, this list of conditions and the following disclaimer in the
13    documentation and/or other materials provided with the distribution.
14    
15    - Neither the name of the Xiph.org Foundation nor the names of its
16    contributors may be used to endorse or promote products derived from
17    this software without specific prior written permission.
18    
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 #include <math.h>
33 #include <stdlib.h>
34 #include "cwrs.h"
35 #include "vq.h"
36
37
38 /* Improved algebraic pulse-base quantiser. The signal x is replaced by the sum of the pitch 
39    a combination of pulses such that its norm is still equal to 1. The only difference with 
40    the quantiser above is that the search is more complete. */
41 void alg_quant(float *x, float *W, int N, int K, float *p, float alpha, ec_enc *enc)
42 {
43    int L = 3;
44    //float tata[200];
45    float y[L][N];
46    int iy[L][N];
47    //float tata2[200];
48    float ny[L][N];
49    int iny[L][N];
50    int i, j, m;
51    float xy[L], nxy[L];
52    float yy[L], nyy[L];
53    float yp[L], nyp[L];
54    float best_scores[L];
55    float Rpp=0, Rxp=0;
56    float gain[L];
57    int maxL = 1;
58    
59    for (j=0;j<N;j++)
60       Rpp += p[j]*p[j];
61    //if (Rpp>.01)
62    //   alpha = (1-sqrt(1-Rpp))/Rpp;
63    for (j=0;j<N;j++)
64       Rxp += x[j]*p[j];
65    for (m=0;m<L;m++)
66       for (i=0;i<N;i++)
67          y[m][i] = 0;
68       
69    for (m=0;m<L;m++)
70       for (i=0;i<N;i++)
71          ny[m][i] = 0;
72
73    for (m=0;m<L;m++)
74       for (i=0;i<N;i++)
75          iy[m][i] = iny[m][i] = 0;
76
77    for (m=0;m<L;m++)
78       xy[m] = yy[m] = yp[m] = gain[m] = 0;
79    
80    for (i=0;i<K;i++)
81    {
82       int L2 = L;
83       if (L>maxL)
84       {
85          L2 = maxL;
86          maxL *= N;
87       }
88       for (m=0;m<L;m++)
89          best_scores[m] = -1e10;
90
91       for (m=0;m<L2;m++)
92       {
93          for (j=0;j<N;j++)
94          {
95             int sign;
96             for (sign=-1;sign<=1;sign+=2)
97             {
98                if (iy[m][j]*sign < 0)
99                   continue;
100                //fprintf (stderr, "%d/%d %d/%d %d/%d\n", i, K, m, L2, j, N);
101                float tmp_xy, tmp_yy, tmp_yp;
102                float score;
103                float g;
104                float s = sign;
105                tmp_xy = xy[m] + s*x[j]               - alpha*s*p[j]*Rxp;
106                tmp_yy = yy[m] + 2*s*y[m][j] + 1      +alpha*alpha*p[j]*p[j]*Rpp - 2*alpha*s*p[j]*yp[m] - 2*alpha*p[j]*p[j];
107                tmp_yp = yp[m] + s*p[j]               *(1-alpha*Rpp);
108                g = (sqrt(tmp_yp*tmp_yp + tmp_yy - tmp_yy*Rpp) - tmp_yp)/tmp_yy;
109                score = 2*g*tmp_xy - g*g*tmp_yy;
110
111                if (score>best_scores[L-1])
112                {
113                   int k, n;
114                   int id = L-1;
115                   while (id > 0 && score > best_scores[id-1])
116                      id--;
117                
118                   for (k=L-1;k>id;k--)
119                   {
120                      nxy[k] = nxy[k-1];
121                      nyy[k] = nyy[k-1];
122                      nyp[k] = nyp[k-1];
123                      //fprintf(stderr, "%d %d \n", N, k);
124                      for (n=0;n<N;n++)
125                         ny[k][n] = ny[k-1][n];
126                      for (n=0;n<N;n++)
127                         iny[k][n] = iny[k-1][n];
128                      gain[k] = gain[k-1];
129                      best_scores[k] = best_scores[k-1];
130                   }
131
132                   nxy[id] = tmp_xy;
133                   nyy[id] = tmp_yy;
134                   nyp[id] = tmp_yp;
135                   gain[id] = g;
136                   for (n=0;n<N;n++)
137                      ny[id][n] = y[m][n];
138                   ny[id][j] += s;
139                   for (n=0;n<N;n++)
140                      ny[id][n] -= alpha*s*p[j]*p[n];
141                
142                   for (n=0;n<N;n++)
143                      iny[id][n] = iy[m][n];
144                   if (s>0)
145                      iny[id][j] += 1;
146                   else
147                      iny[id][j] -= 1;
148                   best_scores[id] = score;
149                }
150             }   
151          }
152          
153       }
154       int k,n;
155       for (k=0;k<L;k++)
156       {
157          xy[k] = nxy[k];
158          yy[k] = nyy[k];
159          yp[k] = nyp[k];
160          for (n=0;n<N;n++)
161             y[k][n] = ny[k][n];
162          for (n=0;n<N;n++)
163             iy[k][n] = iny[k][n];
164       }
165
166    }
167    
168    for (i=0;i<N;i++)
169       x[i] = p[i]+gain[0]*y[0][i];
170    if (0) {
171       float E=1e-15;
172       int ABS = 0;
173       for (i=0;i<N;i++)
174          ABS += abs(iy[0][i]);
175       //if (K != ABS)
176       //   printf ("%d %d\n", K, ABS);
177       for (i=0;i<N;i++)
178          E += x[i]*x[i];
179       //printf ("%f\n", E);
180       E = 1/sqrt(E);
181       for (i=0;i<N;i++)
182          x[i] *= E;
183    }
184    int comb[K];
185    int signs[K];
186    //for (i=0;i<N;i++)
187    //   printf ("%d ", iy[0][i]);
188    pulse2comb(N, K, comb, signs, iy[0]); 
189    ec_enc_uint64(enc,icwrs64(N, K, comb, signs),ncwrs64(N, K));
190    
191    /* Recompute the gain in one pass (to reduce errors) */
192    if (0) {
193       float Ryp=0;
194       float Rpp=0;
195       float Ryy=0;
196       float g=0;
197       for (i=0;i<N;i++)
198          Rpp += p[i]*p[i];
199       
200       for (i=0;i<N;i++)
201          Ryp += iy[0][i]*p[i];
202       
203       for (i=0;i<N;i++)
204          y[0][i] = iy[0][i] - alpha*Ryp*p[i];
205       
206       /* Recompute after the projection (I think it's right) */
207       Ryp = 0;
208       for (i=0;i<N;i++)
209          Ryp += y[0][i]*p[i];
210       
211       for (i=0;i<N;i++)
212          Ryy += y[0][i]*y[0][i];
213       
214       g = (sqrt(Ryp*Ryp + Ryy - Ryy*Rpp) - Ryp)/Ryy;
215         
216       for (i=0;i<N;i++)
217          x[i] = p[i] + g*y[0][i];
218       
219    }
220
221 }
222
223 void alg_unquant(float *x, int N, int K, float *p, float alpha, ec_dec *dec)
224 {
225    int i;
226    celt_uint64_t id;
227    int comb[K];
228    int signs[K];
229    int iy[N];
230    float y[N];
231    float Rpp=0, Ryp=0, Ryy=0;
232    float g;
233
234    id = ec_dec_uint64(dec, ncwrs64(N, K));
235    cwrsi64(N, K, id, comb, signs);
236    comb2pulse(N, K, iy, comb, signs);
237    //for (i=0;i<N;i++)
238    //   printf ("%d ", iy[i]);
239    for (i=0;i<N;i++)
240       Rpp += p[i]*p[i];
241
242    for (i=0;i<N;i++)
243       Ryp += iy[i]*p[i];
244
245    for (i=0;i<N;i++)
246       y[i] = iy[i] - alpha*Ryp*p[i];
247
248    /* Recompute after the projection (I think it's right) */
249    Ryp = 0;
250    for (i=0;i<N;i++)
251       Ryp += y[i]*p[i];
252
253    for (i=0;i<N;i++)
254       Ryy += y[i]*y[i];
255
256    g = (sqrt(Ryp*Ryp + Ryy - Ryy*Rpp) - Ryp)/Ryy;
257
258    for (i=0;i<N;i++)
259       x[i] = p[i] + g*y[i];
260 }
261
262
263 static const float pg[11] = {1.f, .75f, .65f, 0.6f, 0.6f, .6f, .55f, .55f, .5f, .5f, .5f};
264
265 void intra_prediction(float *x, float *W, int N, int K, float *Y, float *P, int B, int N0, ec_enc *enc)
266 {
267    int i,j;
268    int best=0;
269    float best_score=0;
270    float s = 1;
271    int sign;
272    float E;
273    for (i=0;i<N0*B-N;i+=B)
274    {
275       int j;
276       float xy=0, yy=0;
277       float score;
278       for (j=0;j<N;j++)
279       {
280          xy += x[j]*Y[i+j];
281          yy += Y[i+j]*Y[i+j];
282       }
283       score = xy*xy/(.001+yy);
284       if (score > best_score)
285       {
286          best_score = score;
287          best = i;
288          if (xy>0)
289             s = 1;
290          else
291             s = -1;
292       }
293    }
294    if (s<0)
295       sign = 1;
296    else
297       sign = 0;
298    //printf ("%d %d ", sign, best);
299    ec_enc_uint(enc,sign,2);
300    ec_enc_uint(enc,best/B,N0-N/B);
301    //printf ("%d %f\n", best, best_score);
302    
303    float pred_gain;
304    if (K>10)
305       pred_gain = pg[10];
306    else
307       pred_gain = pg[K];
308    E = 1e-10;
309    for (j=0;j<N;j++)
310    {
311       P[j] = s*Y[best+j];
312       E += P[j]*P[j];
313    }
314    E = pred_gain/sqrt(E);
315    for (j=0;j<N;j++)
316       P[j] *= E;
317    if (K>0)
318    {
319       for (j=0;j<N;j++)
320          x[j] -= P[j];
321    } else {
322       for (j=0;j<N;j++)
323          x[j] = P[j];
324    }
325    //printf ("quant ");
326    //for (j=0;j<N;j++) printf ("%f ", P[j]);
327
328 }
329
330 void intra_unquant(float *x, int N, int K, float *Y, float *P, int B, int N0, ec_dec *dec)
331 {
332    int j;
333    int sign;
334    float s;
335    int best;
336    float E;
337    sign = ec_dec_uint(dec, 2);
338    if (sign == 0)
339       s = 1;
340    else
341       s = -1;
342    
343    best = B*ec_dec_uint(dec, N0-N/B);
344    //printf ("%d %d ", sign, best);
345
346    float pred_gain;
347    if (K>10)
348       pred_gain = pg[10];
349    else
350       pred_gain = pg[K];
351    E = 1e-10;
352    for (j=0;j<N;j++)
353    {
354       P[j] = s*Y[best+j];
355       E += P[j]*P[j];
356    }
357    E = pred_gain/sqrt(E);
358    for (j=0;j<N;j++)
359       P[j] *= E;
360    if (K==0)
361    {
362       for (j=0;j<N;j++)
363          x[j] = P[j];
364    }
365 }