Reorders some operations in anti-collapse to reuse values
[opus.git] / libcelt / bands.c
index 0951c54..132c9d9 100644 (file)
@@ -212,56 +212,62 @@ void denormalise_bands(const CELTMode *m, const celt_norm * restrict X, celt_sig
 }
 
 /* This prevents energy collapse for transients with multiple short MDCTs */
-void anti_collapse(const CELTMode *m, celt_norm *_X, int LM, int C, int size,
+void anti_collapse(const CELTMode *m, celt_norm *_X, unsigned char *collapse_masks, int LM, int C, int size,
       int start, int end, celt_word16 *logE, celt_word16 *prev1logE,
       celt_word16 *prev2logE, int *pulses, celt_uint32 seed)
 {
    int c, i, j, k;
-   c=0; do
+   for (i=start;i<end;i++)
    {
-      for (i=start;i<end;i++)
+      int N0;
+      celt_word16 thresh, sqrt_1;
+      int depth;
+#ifdef FIXED_POINT
+      int shift;
+#endif
+
+      N0 = m->eBands[i+1]-m->eBands[i];
+      depth = (1+(pulses[i]>>BITRES))/(m->eBands[i+1]-m->eBands[i]<<LM);
+
+#ifdef FIXED_POINT
+      thresh = MULT16_32_Q15(QCONST16(0.3f, 15), MIN32(32767,SHR32(celt_exp2(-SHL16(depth, 11)),1) ));
+      {
+         celt_word32 t;
+         t = N0<<LM;
+         shift = celt_ilog2(t)>>1;
+         t = SHL32(t, (7-shift)<<1);
+         sqrt_1 = celt_rsqrt_norm(t);
+      }
+#else
+      thresh = .3f*celt_exp2(-depth);
+      sqrt_1 = celt_rsqrt(N0<<LM);
+#endif
+
+      c=0; do
       {
          celt_norm *X;
-         int N0;
          celt_word16 Ediff;
          celt_word16 r;
-         celt_word16 thresh;
-         int depth;
-
-         N0 = m->eBands[i+1]-m->eBands[i];
          Ediff = logE[c*m->nbEBands+i]-MIN16(prev1logE[c*m->nbEBands+i],prev2logE[c*m->nbEBands+i]);
          Ediff = MAX16(0, Ediff);
-         depth = (1+(pulses[i]>>BITRES))/(m->eBands[i+1]-m->eBands[i]<<LM);
 
 #ifdef FIXED_POINT
-         thresh = MULT16_32_Q15(QCONST16(0.3f, 15), MIN32(32767,SHR32(celt_exp2(-SHL16(depth, 11)),1) ));
          if (Ediff < 16384)
             r = 2*MIN16(16383,SHR32(celt_exp2(-SHL16(Ediff, 11-DB_SHIFT)),1));
          else
             r = 0;
          r = SHR16(MIN16(thresh, r),1);
-         {
-            int shift;
-            celt_word32 t;
-            t = N0<<LM;
-            shift = celt_ilog2(t)>>1;
-            t = SHL32(t, (7-shift)<<1);
-            r = SHR32(MULT16_16_Q15(celt_rsqrt_norm(t), r),shift);
-         }
+         r = SHR32(MULT16_16_Q15(sqrt_1, r),shift);
 #else
-         thresh = .3f*celt_exp2(-depth);
          r = 2.f*celt_exp2(-Ediff);
          r = MIN16(thresh, r);
-         r = r*celt_rsqrt(N0<<LM);
+         r = r*sqrt_1;
 #endif
          X = _X+c*size+(m->eBands[i]<<LM);
          for (k=0;k<1<<LM;k++)
          {
-            celt_word32 sum=0;
             /* Detect collapse */
-            for (j=0;j<N0;j++)
-               sum += ABS16(X[(j<<LM)+k]);
-            if (sum<QCONST16(1e-4, 14))
+            if (!(collapse_masks[i*C+c]&1<<k))
             {
                /* Fill with noise */
                for (j=0;j<N0;j++)
@@ -273,8 +279,8 @@ void anti_collapse(const CELTMode *m, celt_norm *_X, int LM, int C, int size,
          }
          /* We just added some energy, so we need to renormalise */
          renormalise_vector(X, N0<<LM, Q15ONE);
-      }
-   } while (++c<C);
+      } while (++c<C);
+   }
 
 }
 
@@ -605,10 +611,10 @@ static int compute_qn(int N, int b, int offset, int stereo)
    the mono and stereo case. Even in the mono case, it can split the band
    in two and transmit the energy difference with the two half-bands. It
    can be called recursively so bands can end up being split in 8 parts. */
-static void quant_band(int encode, const CELTMode *m, int i, celt_norm *X, celt_norm *Y,
+static unsigned quant_band(int encode, const CELTMode *m, int i, celt_norm *X, celt_norm *Y,
       int N, int b, int spread, int B, int intensity, int tf_change, celt_norm *lowband, int resynth, void *ec,
       celt_int32 *remaining_bits, int LM, celt_norm *lowband_out, const celt_ener *bandE, int level,
-      celt_int32 *seed, celt_word16 gain, celt_norm *lowband_scratch, int fill)
+      celt_uint32 *seed, celt_word16 gain, celt_norm *lowband_scratch, int fill)
 {
    int q;
    int curr_bits;
@@ -623,6 +629,7 @@ static void quant_band(int encode, const CELTMode *m, int i, celt_norm *X, celt_
    int inv = 0;
    celt_word16 mid=0, side=0;
    int longBlocks;
+   unsigned cm=0;
 
    longBlocks = B0==1;
 
@@ -656,7 +663,7 @@ static void quant_band(int encode, const CELTMode *m, int i, celt_norm *X, celt_
       } while (++c<1+stereo);
       if (lowband_out)
          lowband_out[0] = SHR16(X[0],4);
-      return;
+      return 1;
    }
 
    if (!stereo && level == 0)
@@ -680,6 +687,7 @@ static void quant_band(int encode, const CELTMode *m, int i, celt_norm *X, celt_
             haar1(X, N>>k, 1<<k);
          if (lowband)
             haar1(lowband, N>>k, 1<<k);
+         fill |= fill<<(1<<k);
       }
       B>>=recombine;
       N_B<<=recombine;
@@ -691,6 +699,7 @@ static void quant_band(int encode, const CELTMode *m, int i, celt_norm *X, celt_
             haar1(X, N_B, B);
          if (lowband)
             haar1(lowband, N_B, B);
+         fill |= fill<<B;
          B <<= 1;
          N_B >>= 1;
          time_divide++;
@@ -718,6 +727,8 @@ static void quant_band(int encode, const CELTMode *m, int i, celt_norm *X, celt_
          Y = X+N;
          split = 1;
          LM -= 1;
+         if (B==1)
+            fill |= fill<<1;
          B = (B+1)>>1;
       }
    }
@@ -853,11 +864,13 @@ static void quant_band(int encode, const CELTMode *m, int i, celt_norm *X, celt_
       {
          imid = 32767;
          iside = 0;
+         fill &= (1<<B)-1;
          delta = -16384;
       } else if (itheta == 16384)
       {
          imid = 0;
          iside = 32767;
+         fill &= (1<<B)-1<<B;
          delta = 16384;
       } else {
          imid = bitexact_cos(itheta);
@@ -906,7 +919,9 @@ static void quant_band(int encode, const CELTMode *m, int i, celt_norm *X, celt_
             }
          }
          sign = 1-2*sign;
-         quant_band(encode, m, i, x2, NULL, N, mbits, spread, B, intensity, tf_change, lowband, resynth, ec, remaining_bits, LM, lowband_out, NULL, level, seed, gain, lowband_scratch, fill);
+         cm = quant_band(encode, m, i, x2, NULL, N, mbits, spread, B, intensity, tf_change, lowband, resynth, ec, remaining_bits, LM, lowband_out, NULL, level, seed, gain, lowband_scratch, fill);
+         /* We don't split N=2 bands, so cm is either 1 or 0 (for a fold-collapse),
+             and there's no need to worry about mixing with the other channel. */
          y2[0] = -sign*x2[1];
          y2[1] = sign*x2[0];
          if (resynth)
@@ -955,12 +970,14 @@ static void quant_band(int encode, const CELTMode *m, int i, celt_norm *X, celt_
 
          /* In stereo mode, we do not apply a scaling to the mid because we need the normalized
             mid for folding later */
-         quant_band(encode, m, i, X, NULL, N, mbits, spread, B, intensity, tf_change,
+         cm = quant_band(encode, m, i, X, NULL, N, mbits, spread, B, intensity, tf_change,
                lowband, resynth, ec, remaining_bits, LM, next_lowband_out1,
                NULL, next_level, seed, stereo ? Q15ONE : MULT16_16_P15(gain,mid), lowband_scratch, fill);
-         quant_band(encode, m, i, Y, NULL, N, sbits, spread, B, intensity, tf_change,
+         /* For a stereo split, the high bits of fill are always zero, so no
+             folding will be done to the side. */
+         cm |= quant_band(encode, m, i, Y, NULL, N, sbits, spread, B, intensity, tf_change,
                next_lowband2, resynth, ec, remaining_bits, LM, NULL,
-               NULL, next_level, seed, MULT16_16_P15(gain,side), NULL, fill && !stereo);
+               NULL, next_level, seed, MULT16_16_P15(gain,side), NULL, fill>>B)<<B;
       }
 
    } else {
@@ -984,9 +1001,9 @@ static void quant_band(int encode, const CELTMode *m, int i, celt_norm *X, celt_
 
          /* Finally do the actual quantization */
          if (encode)
-            alg_quant(X, N, K, spread, B, lowband, resynth, (ec_enc*)ec, seed, gain);
+            cm = alg_quant(X, N, K, spread, B, lowband, resynth, (ec_enc*)ec, gain);
          else
-            alg_unquant(X, N, K, spread, B, lowband, (ec_dec*)ec, seed, gain);
+            cm = alg_unquant(X, N, K, spread, B, lowband, (ec_dec*)ec, gain);
       } else {
          /* If there's no pulse, fill the band anyway */
          int j;
@@ -1005,10 +1022,12 @@ static void quant_band(int encode, const CELTMode *m, int i, celt_norm *X, celt_
                      *seed = lcg_rand(*seed);
                      X[j] = (celt_int32)(*seed)>>20;
                   }
+                  cm = (1<<B)-1;
                } else {
                   /* Folded spectrum */
                   for (j=0;j<N;j++)
                      X[j] = lowband[j];
+                  cm = fill;
                }
                renormalise_vector(X, N, gain);
             }
@@ -1022,7 +1041,10 @@ static void quant_band(int encode, const CELTMode *m, int i, celt_norm *X, celt_
       if (stereo)
       {
          if (N!=2)
+         {
+            cm |= cm>>B;
             stereo_merge(X, Y, mid, N);
+         }
          if (inv)
          {
             int j;
@@ -1044,11 +1066,15 @@ static void quant_band(int encode, const CELTMode *m, int i, celt_norm *X, celt_
          {
             B >>= 1;
             N_B <<= 1;
+            cm |= cm>>B;
             haar1(X, N_B, B);
          }
 
          for (k=0;k<recombine;k++)
+         {
+            cm |= cm<<(1<<k);
             haar1(X, N0>>k, 1<<k);
+         }
          B<<=recombine;
          N_B>>=recombine;
 
@@ -1063,12 +1089,13 @@ static void quant_band(int encode, const CELTMode *m, int i, celt_norm *X, celt_
          }
       }
    }
+   return cm;
 }
 
 void quant_all_bands(int encode, const CELTMode *m, int start, int end,
-      celt_norm *_X, celt_norm *_Y, const celt_ener *bandE, int *pulses,
+      celt_norm *_X, celt_norm *_Y, unsigned char *collapse_masks, const celt_ener *bandE, int *pulses,
       int shortBlocks, int spread, int dual_stereo, int intensity, int *tf_res, int resynth,
-      int total_bits, void *ec, int LM, int codedBands)
+      int total_bits, void *ec, int LM, int codedBands, ec_uint32 *seed)
 {
    int i;
    celt_int32 balance;
@@ -1079,7 +1106,6 @@ void quant_all_bands(int encode, const CELTMode *m, int start, int end,
    VARDECL(celt_norm, lowband_scratch);
    int B;
    int M;
-   celt_int32 seed;
    int lowband_offset;
    int update_lowband = 1;
    int C = _Y != NULL ? 2 : 1;
@@ -1092,12 +1118,8 @@ void quant_all_bands(int encode, const CELTMode *m, int start, int end,
    norm = _norm;
    norm2 = norm + M*eBands[m->nbEBands];
 
-   if (encode)
-      seed = ((ec_enc*)ec)->rng;
-   else
-      seed = ((ec_dec*)ec)->rng;
    balance = 0;
-   lowband_offset = -1;
+   lowband_offset = 0;
    for (i=start;i<end;i++)
    {
       celt_int32 tell;
@@ -1107,7 +1129,9 @@ void quant_all_bands(int encode, const CELTMode *m, int start, int end,
       int effective_lowband=-1;
       celt_norm * restrict X, * restrict Y;
       int tf_change=0;
-      
+      unsigned x_cm;
+      unsigned y_cm;
+
       X = _X+M*eBands[i];
       if (_Y!=NULL)
          Y = _Y+M*eBands[i];
@@ -1131,8 +1155,8 @@ void quant_all_bands(int encode, const CELTMode *m, int start, int end,
          b = 0;
       }
 
-      if (M*eBands[i]-N >= M*eBands[start] && (update_lowband || lowband_offset==-1))
-            lowband_offset = M*eBands[i];
+      if (M*eBands[i]-N >= M*eBands[start] && (update_lowband || lowband_offset==0))
+            lowband_offset = i;
 
       tf_change = tf_res[i];
       if (i>=m->effEBands)
@@ -1143,8 +1167,30 @@ void quant_all_bands(int encode, const CELTMode *m, int start, int end,
       }
 
       /* This ensures we never repeat spectral content within one band */
-      if (lowband_offset != -1)
-         effective_lowband = IMAX(M*eBands[start], lowband_offset-N);
+      if (lowband_offset != 0)
+         effective_lowband = IMAX(M*eBands[start], M*eBands[lowband_offset]-N);
+
+      /* Get a conservative estimate of the collapse_mask's for the bands we're
+          going to be folding from. */
+      if (lowband_offset != 0 && (spread!=SPREAD_AGGRESSIVE || B>1))
+      {
+         int fold_start;
+         int fold_end;
+         int fold_i;
+         fold_start = lowband_offset;
+         while(M*eBands[--fold_start] > effective_lowband);
+         fold_end = lowband_offset-1;
+         while(M*eBands[++fold_end] < effective_lowband+N);
+         x_cm = y_cm = 0;
+         fold_i = fold_start; do {
+           x_cm |= collapse_masks[fold_i*C+0];
+           y_cm |= collapse_masks[fold_i*C+1];
+         } while (++fold_i<fold_end);
+      }
+      /* Otherwise, we'll be using the LCG to fold, so all blocks will (almost
+          always) be non-zero.*/
+      else
+         x_cm = y_cm = (1<<B)-1;
 
       if (dual_stereo && i==intensity)
       {
@@ -1157,16 +1203,19 @@ void quant_all_bands(int encode, const CELTMode *m, int start, int end,
       }
       if (dual_stereo)
       {
-         quant_band(encode, m, i, X, NULL, N, b/2, spread, B, intensity, tf_change,
+         x_cm = quant_band(encode, m, i, X, NULL, N, b/2, spread, B, intensity, tf_change,
                effective_lowband != -1 ? norm+effective_lowband : NULL, resynth, ec, &remaining_bits, LM,
-               norm+M*eBands[i], bandE, 0, &seed, Q15ONE, lowband_scratch, 1);
-         quant_band(encode, m, i, Y, NULL, N, b/2, spread, B, intensity, tf_change,
+               norm+M*eBands[i], bandE, 0, seed, Q15ONE, lowband_scratch, x_cm);
+         y_cm = quant_band(encode, m, i, Y, NULL, N, b/2, spread, B, intensity, tf_change,
                effective_lowband != -1 ? norm2+effective_lowband : NULL, resynth, ec, &remaining_bits, LM,
-               norm2+M*eBands[i], bandE, 0, &seed, Q15ONE, lowband_scratch, 1);
+               norm2+M*eBands[i], bandE, 0, seed, Q15ONE, lowband_scratch, y_cm);
+         collapse_masks[i*2+0] = (unsigned char)(x_cm&(1<<B)-1);
+         collapse_masks[i*2+1] = (unsigned char)(y_cm&(1<<B)-1);
       } else {
-         quant_band(encode, m, i, X, Y, N, b, spread, B, intensity, tf_change,
+         x_cm = quant_band(encode, m, i, X, Y, N, b, spread, B, intensity, tf_change,
                effective_lowband != -1 ? norm+effective_lowband : NULL, resynth, ec, &remaining_bits, LM,
-               norm+M*eBands[i], bandE, 0, &seed, Q15ONE, lowband_scratch, 1);
+               norm+M*eBands[i], bandE, 0, seed, Q15ONE, lowband_scratch, x_cm|y_cm);
+         collapse_masks[i*C+1] = collapse_masks[i*C+0] = (unsigned char)(x_cm&(1<<B)-1);
       }
       balance += pulses[i] + tell;