Some work towards being able to encode a 48 kHz stream from 32 kHz audio (incomplete)
[opus.git] / libcelt / bands.c
index 4b49b28..fd5ce89 100644 (file)
 
 #ifdef FIXED_POINT
 /* Compute the amplitude (sqrt energy) in each of the bands */
-void compute_band_energies(const CELTMode *m, const celt_sig *X, celt_ener *bank, int _C, int M)
+void compute_band_energies(const CELTMode *m, const celt_sig *X, celt_ener *bank, int end, int _C, int M)
 {
    int i, c, N;
    const celt_int16 *eBands = m->eBands;
    const int C = CHANNELS(_C);
-   N = M*m->eBands[m->nbEBands+1];
+   N = M*m->shortMdctSize;
    for (c=0;c<C;c++)
    {
-      for (i=0;i<m->nbEBands;i++)
+      for (i=0;i<end;i++)
       {
          int j;
          celt_word32 maxval=0;
@@ -87,12 +87,12 @@ void compute_band_energies(const CELTMode *m, const celt_sig *X, celt_ener *bank
 }
 
 /* Normalise each band such that the energy is one. */
-void normalise_bands(const CELTMode *m, const celt_sig * restrict freq, celt_norm * restrict X, const celt_ener *bank, int _C, int M)
+void normalise_bands(const CELTMode *m, const celt_sig * restrict freq, celt_norm * restrict X, const celt_ener *bank, int end, int _C, int M)
 {
    int i, c, N;
    const celt_int16 *eBands = m->eBands;
    const int C = CHANNELS(_C);
-   N = M*m->eBands[m->nbEBands+1];
+   N = M*m->shortMdctSize;
    for (c=0;c<C;c++)
    {
       i=0; do {
@@ -105,21 +105,21 @@ void normalise_bands(const CELTMode *m, const celt_sig * restrict freq, celt_nor
          j=M*eBands[i]; do {
             X[j+c*N] = MULT16_16_Q15(VSHR32(freq[j+c*N],shift-1),g);
          } while (++j<M*eBands[i+1]);
-      } while (++i<m->nbEBands);
+      } while (++i<end);
    }
 }
 
 #else /* FIXED_POINT */
 /* Compute the amplitude (sqrt energy) in each of the bands */
-void compute_band_energies(const CELTMode *m, const celt_sig *X, celt_ener *bank, int _C, int M)
+void compute_band_energies(const CELTMode *m, const celt_sig *X, celt_ener *bank, int end, int _C, int M)
 {
    int i, c, N;
    const celt_int16 *eBands = m->eBands;
    const int C = CHANNELS(_C);
-   N = M*m->eBands[m->nbEBands+1];
+   N = M*m->shortMdctSize;
    for (c=0;c<C;c++)
    {
-      for (i=0;i<m->nbEBands;i++)
+      for (i=0;i<end;i++)
       {
          int j;
          celt_word32 sum = 1e-10;
@@ -133,15 +133,15 @@ void compute_band_energies(const CELTMode *m, const celt_sig *X, celt_ener *bank
 }
 
 /* Normalise each band such that the energy is one. */
-void normalise_bands(const CELTMode *m, const celt_sig * restrict freq, celt_norm * restrict X, const celt_ener *bank, int _C, int M)
+void normalise_bands(const CELTMode *m, const celt_sig * restrict freq, celt_norm * restrict X, const celt_ener *bank, int end, int _C, int M)
 {
    int i, c, N;
    const celt_int16 *eBands = m->eBands;
    const int C = CHANNELS(_C);
-   N = M*m->eBands[m->nbEBands+1];
+   N = M*m->shortMdctSize;
    for (c=0;c<C;c++)
    {
-      for (i=0;i<m->nbEBands;i++)
+      for (i=0;i<end;i++)
       {
          int j;
          celt_word16 g = 1.f/(1e-10f+bank[i+c*m->nbEBands]);
@@ -153,7 +153,7 @@ void normalise_bands(const CELTMode *m, const celt_sig * restrict freq, celt_nor
 
 #endif /* FIXED_POINT */
 
-void renormalise_bands(const CELTMode *m, celt_norm * restrict X, int _C, int M)
+void renormalise_bands(const CELTMode *m, celt_norm * restrict X, int end, int _C, int M)
 {
    int i, c;
    const celt_int16 *eBands = m->eBands;
@@ -161,18 +161,18 @@ void renormalise_bands(const CELTMode *m, celt_norm * restrict X, int _C, int M)
    for (c=0;c<C;c++)
    {
       i=0; do {
-         renormalise_vector(X+M*eBands[i]+c*M*eBands[m->nbEBands+1], Q15ONE, M*eBands[i+1]-M*eBands[i], 1);
-      } while (++i<m->nbEBands);
+         renormalise_vector(X+M*eBands[i]+c*M*m->shortMdctSize, Q15ONE, M*eBands[i+1]-M*eBands[i], 1);
+      } while (++i<end);
    }
 }
 
 /* De-normalise the energy to produce the synthesis from the unit-energy bands */
-void denormalise_bands(const CELTMode *m, const celt_norm * restrict X, celt_sig * restrict freq, const celt_ener *bank, int _C, int M)
+void denormalise_bands(const CELTMode *m, const celt_norm * restrict X, celt_sig * restrict freq, const celt_ener *bank, int end, int _C, int M)
 {
    int i, c, N;
    const celt_int16 *eBands = m->eBands;
    const int C = CHANNELS(_C);
-   N = M*m->eBands[m->nbEBands+1];
+   N = M*m->shortMdctSize;
    if (C>2)
       celt_fatal("denormalise_bands() not implemented for >2 channels");
    for (c=0;c<C;c++)
@@ -181,18 +181,18 @@ void denormalise_bands(const CELTMode *m, const celt_norm * restrict X, celt_sig
       const celt_norm * restrict x;
       f = freq+c*N;
       x = X+c*N;
-      for (i=0;i<m->nbEBands;i++)
+      for (i=0;i<end;i++)
       {
-         int j, end;
+         int j, band_end;
          celt_word32 g = SHR32(bank[i+c*m->nbEBands],1);
          j=M*eBands[i];
-         end = M*eBands[i+1];
+         band_end = M*eBands[i+1];
          do {
             *f++ = SHL32(MULT16_32_Q15(*x, g),2);
             x++;
-         } while (++j<end);
+         } while (++j<band_end);
       }
-      for (i=M*eBands[m->nbEBands];i<M*eBands[m->nbEBands+1];i++)
+      for (i=M*eBands[m->nbEBands];i<N;i++)
          *f++ = 0;
    }
 }
@@ -205,7 +205,7 @@ int compute_pitch_gain(const CELTMode *m, const celt_sig *X, const celt_sig *P,
    const int C = CHANNELS(_C);
    celt_word32 Sxy=0, Sxx=0, Syy=0;
    int len = M*m->pitchEnd;
-   int N = M*m->eBands[m->nbEBands+1];
+   int N = M*m->shortMdctSize;
 #ifdef FIXED_POINT
    int shift = 0;
    celt_word32 maxabs=0;
@@ -298,7 +298,7 @@ void apply_pitch(const CELTMode *m, celt_sig *X, const celt_sig *P, int gain_id,
    const int C = CHANNELS(_C);
    int len = M*m->pitchEnd;
 
-   N = M*m->eBands[m->nbEBands+1];
+   N = M*m->shortMdctSize;
    gain = ADD16(QCONST16(.5f,14), MULT16_16_16(QCONST16(.05f,14),gain_id));
    delta = PDIV32_16(gain, len);
    if (pred)
@@ -349,7 +349,7 @@ static void stereo_band_mix(const CELTMode *m, celt_norm *X, celt_norm *Y, const
 }
 
 
-int folding_decision(const CELTMode *m, celt_norm *X, celt_word16 *average, int *last_decision, int _C, int M)
+int folding_decision(const CELTMode *m, celt_norm *X, celt_word16 *average, int *last_decision, int end, int _C, int M)
 {
    int i, c, N0;
    int NR=0;
@@ -357,11 +357,11 @@ int folding_decision(const CELTMode *m, celt_norm *X, celt_word16 *average, int
    const int C = CHANNELS(_C);
    const celt_int16 * restrict eBands = m->eBands;
    
-   N0 = M*m->eBands[m->nbEBands+1];
+   N0 = M*m->shortMdctSize;
 
    for (c=0;c<C;c++)
    {
-   for (i=0;i<m->nbEBands;i++)
+   for (i=0;i<end;i++)
    {
       int j, N;
       int max_i=0;
@@ -469,7 +469,7 @@ static void haar1(celt_norm *X, int N0, int stride)
    in two and transmit the energy difference with the two half-bands. It
    can be called recursively so bands can end up being split in 8 parts. */
 static void quant_band(int encode, const CELTMode *m, int i, celt_norm *X, celt_norm *Y,
-      int N, int b, int spread, int tf_change, celt_norm *lowband, int resynth, ec_enc *ec,
+      int N, int b, int spread, int tf_change, celt_norm *lowband, int resynth, void *ec,
       celt_int32 *remaining_bits, int LM, celt_norm *lowband_out, const celt_ener *bandE, int level)
 {
    int q;
@@ -502,7 +502,7 @@ static void quant_band(int encode, const CELTMode *m, int i, celt_norm *X, celt_
             if (encode)
             {
                sign = x[0]<0;
-               ec_enc_bits(ec, sign, 1);
+               ec_enc_bits((ec_enc*)ec, sign, 1);
             } else {
                sign = ec_dec_bits((ec_dec*)ec, 1);
             }
@@ -553,6 +553,8 @@ static void quant_band(int encode, const CELTMode *m, int i, celt_norm *X, celt_
       spread0 = spread;
       N_B0 = N_B;
    }
+
+   /* Reorganize the samples in time order instead of frequency order */
    if (!stereo && spread0>1 && level==0)
    {
       if (encode)
@@ -583,6 +585,8 @@ static void quant_band(int encode, const CELTMode *m, int i, celt_norm *X, celt_
       celt_word16 mid, side;
       int offset, N2;
       offset = m->logN[i]+(LM<<BITRES)-QTHETA_OFFSET;
+
+      /* Decide on the resolution to give to the split parameter theta */
       N2 = 2*N-1;
       if (stereo && N>2)
          N2--;
@@ -609,8 +613,12 @@ static void quant_band(int encode, const CELTMode *m, int i, celt_norm *X, celt_
             mid = renormalise_vector(X, Q15ONE, N, 1);
             side = renormalise_vector(Y, Q15ONE, N, 1);
 
-            /* 0.63662 = 2/pi */
+            /* theta is the atan() of the ration between the (normalized)
+               side and mid. With just that parameter, we can re-scale both
+               mid and side because we know that 1) they have unit norm and
+               2) they are orthogonal. */
    #ifdef FIXED_POINT
+            /* 0.63662 = 2/pi */
             itheta = MULT16_16_Q15(QCONST16(0.63662f,15),celt_atan2p(side, mid));
    #else
             itheta = floor(.5f+16384*0.63662f*atan2(side,mid));
@@ -624,7 +632,7 @@ static void quant_band(int encode, const CELTMode *m, int i, celt_norm *X, celt_
          if (stereo || qb>9 || spread>1)
          {
             if (encode)
-               ec_enc_uint(ec, itheta, (1<<qb)+1);
+               ec_enc_uint((ec_enc*)ec, itheta, (1<<qb)+1);
             else
                itheta = ec_dec_uint((ec_dec*)ec, (1<<qb)+1);
             qalloc = log2_frac((1<<qb)+1,BITRES);
@@ -647,7 +655,7 @@ static void quant_band(int encode, const CELTMode *m, int i, celt_norm *X, celt_
                      fs--;
                   j++;
                }
-               ec_encode(ec, fl, fl+fs, ft);
+               ec_encode((ec_enc*)ec, fl, fl+fs, ft);
             } else {
                int fl=0;
                int j, fm;
@@ -685,6 +693,8 @@ static void quant_band(int encode, const CELTMode *m, int i, celt_norm *X, celt_
       } else {
          imid = bitexact_cos(itheta);
          iside = bitexact_cos(16384-itheta);
+         /* This is the mid vs side allocation that minimizes squared error
+            in that band. */
          delta = (N-1)*(log2_frac(iside,BITRES+2)-log2_frac(imid,BITRES+2))>>2;
       }
 
@@ -734,7 +744,7 @@ static void quant_band(int encode, const CELTMode *m, int i, celt_norm *X, celt_
          {
             if (encode)
             {
-               ec_enc_bits(ec, sign==1, 1);
+               ec_enc_bits((ec_enc*)ec, sign==1, 1);
             } else {
                sign = 2*ec_dec_bits((ec_dec*)ec, 1)-1;
             }
@@ -801,11 +811,12 @@ static void quant_band(int encode, const CELTMode *m, int i, celt_norm *X, celt_
       }
 
       if (encode)
-         alg_quant(X, N, q, spread, lowband, resynth, ec);
+         alg_quant(X, N, q, spread, lowband, resynth, (ec_enc*)ec);
       else
          alg_unquant(X, N, q, spread, lowband, (ec_dec*)ec);
    }
 
+   /* This code is used by the decoder and by the resynthesis-enabled encoder */
    if (resynth)
    {
       int k;
@@ -834,6 +845,7 @@ static void quant_band(int encode, const CELTMode *m, int i, celt_norm *X, celt_
             interleave_vector(lowband, N_B, spread0);
       }
 
+      /* Undo time-freq changes that we did earlier */
       N_B = N_B0;
       spread = spread0;
       for (k=0;k<time_divide;k++)
@@ -872,9 +884,10 @@ static void quant_band(int encode, const CELTMode *m, int i, celt_norm *X, celt_
    }
 }
 
-void quant_all_bands(int encode, const CELTMode *m, int start, celt_norm *_X, celt_norm *_Y, const celt_ener *bandE, int *pulses, int shortBlocks, int fold, int *tf_res, int resynth, int total_bits, ec_enc *ec, int LM)
+void quant_all_bands(int encode, const CELTMode *m, int start, int end, celt_norm *_X, celt_norm *_Y, const celt_ener *bandE, int *pulses, int shortBlocks, int fold, int *tf_res, int resynth, int total_bits, void *ec, int LM)
 {
-   int i, remaining_bits, balance;
+   int i, balance;
+   celt_int32 remaining_bits;
    const celt_int16 * restrict eBands = m->eBands;
    celt_norm * restrict norm;
    VARDECL(celt_norm, _norm);
@@ -883,6 +896,7 @@ void quant_all_bands(int encode, const CELTMode *m, int start, celt_norm *_X, ce
    int spread;
    celt_norm *lowband;
    int update_lowband = 1;
+   int C = _Y != NULL ? 2 : 1;
    SAVE_STACK;
 
    M = 1<<LM;
@@ -893,7 +907,7 @@ void quant_all_bands(int encode, const CELTMode *m, int start, celt_norm *_X, ce
 
    balance = 0;
    lowband = NULL;
-   for (i=start;i<m->nbEBands;i++)
+   for (i=start;i<end;i++)
    {
       int tell;
       int b;
@@ -909,20 +923,23 @@ void quant_all_bands(int encode, const CELTMode *m, int start, celt_norm *_X, ce
          Y = NULL;
       N = M*eBands[i+1]-M*eBands[i];
       if (encode)
-         tell = ec_enc_tell(ec, BITRES);
+         tell = ec_enc_tell((ec_enc*)ec, BITRES);
       else
          tell = ec_dec_tell((ec_dec*)ec, BITRES);
 
       if (i != start)
          balance -= tell;
       remaining_bits = (total_bits<<BITRES)-tell-1;
-      curr_balance = (m->nbEBands-i);
+      curr_balance = (end-i);
       if (curr_balance > 3)
          curr_balance = 3;
       curr_balance = balance / curr_balance;
       b = IMIN(remaining_bits+1,pulses[i]+curr_balance);
       if (b<0)
          b = 0;
+      /* Prevents ridiculous bit depths */
+      if (b > C*16*N<<BITRES)
+         b = C*16*N<<BITRES;
 
       if (M*eBands[i]-N >= M*eBands[start])
       {
@@ -931,12 +948,7 @@ void quant_all_bands(int encode, const CELTMode *m, int start, celt_norm *_X, ce
       } else
          lowband = NULL;
 
-      if (shortBlocks)
-      {
-         tf_change = tf_res[i] ? -1 : 2;
-      } else {
-         tf_change = tf_res[i] ? -2 : 0;
-      }
+      tf_change = tf_res[i];
       quant_band(encode, m, i, X, Y, N, b, spread, tf_change, lowband, resynth, ec, &remaining_bits, LM, norm+M*eBands[i], bandE, 0);
 
       balance += pulses[i] + tell;