Reduces decoder stack usage by only storing one channel of denormalized MDCT
authorJean-Marc Valin <jmvalin@jmvalin.ca>
Mon, 6 Jan 2014 13:58:38 +0000 (08:58 -0500)
committerJean-Marc Valin <jmvalin@jmvalin.ca>
Mon, 6 Jan 2014 13:58:38 +0000 (08:58 -0500)
celt/bands.c
celt/bands.h
celt/celt.h
celt/celt_decoder.c
celt/celt_encoder.c

index 3308bc4..7e1d97e 100644 (file)
@@ -194,76 +194,73 @@ void normalise_bands(const CELTMode *m, const celt_sig * OPUS_RESTRICT freq, cel
 /* De-normalise the energy to produce the synthesis from the unit-energy bands */
 void denormalise_bands(const CELTMode *m, const celt_norm * OPUS_RESTRICT X,
       celt_sig * OPUS_RESTRICT freq, const opus_val16 *bandLogE, int start,
-      int end, int C, int M, int downsample, int silence)
+      int end, int M, int downsample, int silence)
 {
-   int i, c, N;
+   int i, N;
    int bound;
+   celt_sig * OPUS_RESTRICT f;
+   const celt_norm * OPUS_RESTRICT x;
    const opus_int16 *eBands = m->eBands;
    N = M*m->shortMdctSize;
    bound = M*eBands[end];
    if (downsample!=1)
       bound = IMIN(bound, N/downsample);
-   celt_assert2(C<=2, "denormalise_bands() not implemented for >2 channels");
    if (silence)
    {
       bound = 0;
       start = end = 0;
    }
-   c=0; do {
-      celt_sig * OPUS_RESTRICT f;
-      const celt_norm * OPUS_RESTRICT x;
-      f = freq+c*N;
-      x = X+c*N+M*eBands[start];
-      for (i=0;i<M*eBands[start];i++)
-         *f++ = 0;
-      for (i=start;i<end;i++)
-      {
-         int j, band_end;
-         opus_val16 g;
-         opus_val16 lg;
+   f = freq;
+   x = X+M*eBands[start];
+   for (i=0;i<M*eBands[start];i++)
+      *f++ = 0;
+   for (i=start;i<end;i++)
+   {
+      int j, band_end;
+      opus_val16 g;
+      opus_val16 lg;
 #ifdef FIXED_POINT
-         int shift;
+      int shift;
 #endif
-         j=M*eBands[i];
-         band_end = M*eBands[i+1];
-         lg = ADD16(bandLogE[i+c*m->nbEBands], SHL16((opus_val16)eMeans[i],6));
+      j=M*eBands[i];
+      band_end = M*eBands[i+1];
+      lg = ADD16(bandLogE[i], SHL16((opus_val16)eMeans[i],6));
 #ifndef FIXED_POINT
-         g = celt_exp2(lg);
+      g = celt_exp2(lg);
 #else
-         /* Handle the integer part of the log energy */
-         shift = 16-(lg>>DB_SHIFT);
-         if (shift>31)
-         {
-            shift=0;
-            g=0;
-         } else {
-            /* Handle the fractional part. */
-            g = celt_exp2_frac(lg&((1<<DB_SHIFT)-1));
-         }
-         /* Handle extreme gains with negative shift. */
-         if (shift<0)
-         {
-            /* For shift < -2 we'd be likely to overflow, so we're capping
+      /* Handle the integer part of the log energy */
+      shift = 16-(lg>>DB_SHIFT);
+      if (shift>31)
+      {
+         shift=0;
+         g=0;
+      } else {
+         /* Handle the fractional part. */
+         g = celt_exp2_frac(lg&((1<<DB_SHIFT)-1));
+      }
+      /* Handle extreme gains with negative shift. */
+      if (shift<0)
+      {
+         /* For shift < -2 we'd be likely to overflow, so we're capping
                the gain here. This shouldn't happen unless the bitstream is
                already corrupted. */
-            if (shift < -2)
-            {
-               g = 32767;
-               shift = -2;
-            }
-            do {
-               *f++ = SHL32(MULT16_16(*x++, g), -shift);
-            } while (++j<band_end);
-         } else
+         if (shift < -2)
+         {
+            g = 32767;
+            shift = -2;
+         }
+         do {
+            *f++ = SHL32(MULT16_16(*x++, g), -shift);
+         } while (++j<band_end);
+      } else
 #endif
          /* Be careful of the fixed-point "else" just above when changing this code */
          do {
             *f++ = SHR32(MULT16_16(*x++, g), shift);
          } while (++j<band_end);
-      }
-      celt_assert(start <= end);
-      OPUS_CLEAR(&freq[c*N+bound], N-bound);
-   } while (++c<C);
+   }
+   celt_assert(start <= end);
+   OPUS_CLEAR(&freq[bound], N-bound);
 }
 
 /* This prevents energy collapse for transients with multiple short MDCTs */
index 924f0a1..69901b1 100644 (file)
@@ -60,7 +60,7 @@ void normalise_bands(const CELTMode *m, const celt_sig * OPUS_RESTRICT freq, cel
  */
 void denormalise_bands(const CELTMode *m, const celt_norm * OPUS_RESTRICT X,
       celt_sig * OPUS_RESTRICT freq, const opus_val16 *bandE, int start,
-      int end, int C, int M, int downsample, int silence);
+      int end, int M, int downsample, int silence);
 
 #define SPREAD_NONE       (0)
 #define SPREAD_LIGHT      (1)
index 5deea1f..dcc6c7e 100644 (file)
@@ -205,10 +205,10 @@ void comb_filter(opus_val32 *y, opus_val32 *x, int T0, int T1, int N,
 void init_caps(const CELTMode *m,int *cap,int LM,int C);
 
 #ifdef RESYNTH
-void deemphasis(celt_sig *in[], opus_val16 *pcm, int N, int C, int downsample, const opus_val16 *coef, celt_sig *mem, celt_sig * OPUS_RESTRICT scratch);
-
-void compute_inv_mdcts(const CELTMode *mode, int shortBlocks, celt_sig *X,
-      celt_sig * OPUS_RESTRICT out_mem[], int C, int LM);
+void deemphasis(celt_sig *in[], opus_val16 *pcm, int N, int C, int downsample, const opus_val16 *coef, celt_sig *mem);
+void celt_synthesis(const CELTMode *mode, celt_norm *X, celt_sig * out_syn[],
+      opus_val16 *oldBandE, int start, int effEnd, int C, int CC, int isTransient,
+      int LM, int downsample, int silence);
 #endif
 
 #ifdef __cplusplus
index eefda85..7b79094 100644 (file)
@@ -190,13 +190,17 @@ static OPUS_INLINE opus_val16 SIG2WORD16(celt_sig x)
 #ifndef RESYNTH
 static
 #endif
-void deemphasis(celt_sig *in[], opus_val16 *pcm, int N, int C, int downsample, const opus_val16 *coef, celt_sig *mem, celt_sig * OPUS_RESTRICT scratch)
+void deemphasis(celt_sig *in[], opus_val16 *pcm, int N, int C, int downsample, const opus_val16 *coef,
+      celt_sig *mem)
 {
    int c;
    int Nd;
    int apply_downsampling=0;
    opus_val16 coef0;
+   VARDECL(celt_sig, scratch);
+   SAVE_STACK;
 
+   ALLOC(scratch, N, celt_sig);
    coef0 = coef[0];
    Nd = N/downsample;
    c=0; do {
@@ -250,6 +254,7 @@ void deemphasis(celt_sig *in[], opus_val16 *pcm, int N, int C, int downsample, c
             y[j*C] = SCALEOUT(SIG2WORD16(scratch[j*downsample]));
       }
    } while (++c<C);
+   RESTORE_STACK;
 }
 
 /** Compute the IMDCT and apply window for all sub-frames and
@@ -258,9 +263,9 @@ void deemphasis(celt_sig *in[], opus_val16 *pcm, int N, int C, int downsample, c
 static
 #endif
 void compute_inv_mdcts(const CELTMode *mode, int shortBlocks, celt_sig *X,
-      celt_sig * OPUS_RESTRICT out_mem[], int C, int LM)
+      celt_sig * OPUS_RESTRICT out_mem, int LM)
 {
-   int b, c;
+   int b;
    int B;
    int N;
    int shift;
@@ -276,11 +281,67 @@ void compute_inv_mdcts(const CELTMode *mode, int shortBlocks, celt_sig *X,
       N = mode->shortMdctSize<<LM;
       shift = mode->maxLM-LM;
    }
-   c=0; do {
-      /* IMDCT on the interleaved the sub-frames, overlap-add is performed by the IMDCT */
-      for (b=0;b<B;b++)
-         clt_mdct_backward(&mode->mdct, &X[b+c*N*B], out_mem[c]+N*b, mode->window, overlap, shift, B);
-   } while (++c<C);
+   /* IMDCT on the interleaved the sub-frames, overlap-add is performed by the IMDCT */
+   for (b=0;b<B;b++)
+      clt_mdct_backward(&mode->mdct, &X[b], out_mem+N*b, mode->window, overlap, shift, B);
+}
+
+#ifndef RESYNTH
+static
+#endif
+void celt_synthesis(const CELTMode *mode, celt_norm *X, celt_sig * out_syn[],
+      opus_val16 *oldBandE, int start, int effEnd, int C, int CC, int isTransient,
+      int LM, int downsample, int silence)
+{
+   int c, i;
+   int M, N;
+   int nbEBands;
+   int shortBlocks;
+   int overlap;
+   VARDECL(celt_sig, freq);
+   SAVE_STACK;
+
+   overlap = mode->overlap;
+   nbEBands = mode->nbEBands;
+   N = mode->shortMdctSize<<LM;
+   ALLOC(freq, N, celt_sig); /**< Interleaved signal MDCTs */
+   M = 1<<LM;
+   shortBlocks = isTransient ? M : 0;
+
+   if (CC==2&&C==1)
+   {
+      /* Copying a mono streams to two channels */
+      celt_sig *freq2;
+      denormalise_bands(mode, X, freq, oldBandE, start, effEnd, M,
+            downsample, silence);
+      /* Store a temporary copy in the output buffer because the IMDCT destroys its input. */
+      freq2 = out_syn[1]+overlap/2;
+      OPUS_COPY(freq2, freq, N);
+      compute_inv_mdcts(mode, shortBlocks, freq2, out_syn[0], LM);
+      compute_inv_mdcts(mode, shortBlocks, freq, out_syn[1], LM);
+   } else if (CC==1&&C==2)
+   {
+      /* Downmixing a stereo stream to mono */
+      celt_sig *freq2;
+      freq2 = out_syn[0]+overlap/2;
+      denormalise_bands(mode, X, freq, oldBandE, start, effEnd, M,
+            downsample, silence);
+      /* Use the output buffer as temp array before downmixing. */
+      denormalise_bands(mode, X+N, freq2, oldBandE+nbEBands, start, effEnd, M,
+            downsample, silence);
+      for (i=0;i<N;i++)
+         freq[i] = HALF32(ADD32(freq[i],freq2[i]));
+      /* Compute inverse MDCTs */
+      compute_inv_mdcts(mode, shortBlocks, freq, out_syn[0], LM);
+   } else {
+      /* Normal case (mono or stereo) */
+      c=0; do {
+         denormalise_bands(mode, X+c*N, freq, oldBandE+c*nbEBands, start, effEnd, M,
+               downsample, silence);
+         compute_inv_mdcts(mode, shortBlocks, freq, out_syn[c], LM);
+      } while (++c<CC);
+   }
+   RESTORE_STACK;
 }
 
 static void tf_decode(int start, int end, int isTransient, int *tf_res, int LM, ec_dec *dec)
@@ -347,7 +408,6 @@ static void celt_decode_lost(CELTDecoder * OPUS_RESTRICT st, opus_val16 * OPUS_R
    int loss_count;
    int noise_based;
    const opus_int16 *eBands;
-   VARDECL(celt_sig, scratch);
    SAVE_STACK;
 
    mode = st->mode;
@@ -369,11 +429,9 @@ static void celt_decode_lost(CELTDecoder * OPUS_RESTRICT st, opus_val16 * OPUS_R
    start = st->start;
    downsample = st->downsample;
    noise_based = loss_count >= 5 || start != 0;
-   ALLOC(scratch, noise_based?N*C:N, celt_sig);
    if (noise_based)
    {
       /* Noise-based PLC/CNG */
-      celt_sig *freq;
       VARDECL(celt_norm, X);
       opus_uint32 seed;
       opus_val16 *plcLogE;
@@ -385,7 +443,6 @@ static void celt_decode_lost(CELTDecoder * OPUS_RESTRICT st, opus_val16 * OPUS_R
 
       /* Share the interleaved signal MDCT coefficient buffer with the
          deemphasis scratch buffer. */
-      freq = scratch;
       ALLOC(X, C*N, celt_norm);   /**< Interleaved normalised MDCTs */
 
       if (loss_count >= 5)
@@ -421,14 +478,12 @@ static void celt_decode_lost(CELTDecoder * OPUS_RESTRICT st, opus_val16 * OPUS_R
       }
       st->rng = seed;
 
-      denormalise_bands(mode, X, freq, plcLogE, start, effEnd, C, 1<<LM,
-            downsample, 0);
-
       c=0; do {
          OPUS_MOVE(decode_mem[c], decode_mem[c]+N,
                DECODE_BUFFER_SIZE-N+(overlap>>1));
       } while (++c<C);
-      compute_inv_mdcts(mode, 0, freq, out_syn, C, LM);
+
+      celt_synthesis(mode, X, out_syn, plcLogE, start, effEnd, C, C, 0, LM, st->downsample, 0);
    } else {
       /* Pitch-based PLC */
       const opus_val16 *window;
@@ -639,7 +694,7 @@ static void celt_decode_lost(CELTDecoder * OPUS_RESTRICT st, opus_val16 * OPUS_R
    }
 
    deemphasis(out_syn, pcm, N, C, downsample,
-         mode->preemph, st->preemph_memD, scratch);
+         mode->preemph, st->preemph_memD);
 
    st->loss_count = loss_count+1;
 
@@ -909,18 +964,12 @@ int celt_decode_with_ec(CELTDecoder * OPUS_RESTRICT st, const unsigned char *dat
       anti_collapse(mode, X, collapse_masks, LM, C, N,
             start, end, oldBandE, oldLogE, oldLogE2, pulses, st->rng);
 
-   ALLOC(freq, IMAX(CC,C)*N, celt_sig); /**< Interleaved signal MDCTs */
-
    if (silence)
    {
       for (i=0;i<C*nbEBands;i++)
          oldBandE[i] = -QCONST16(28.f,DB_SHIFT);
    }
 
-   /* Synthesis */
-   denormalise_bands(mode, X, freq, oldBandE, start, effEnd, C, M,
-         st->downsample, silence);
-
    c=0; do {
       OPUS_MOVE(decode_mem[c], decode_mem[c]+N, DECODE_BUFFER_SIZE-N+overlap/2);
    } while (++c<CC);
@@ -929,16 +978,7 @@ int celt_decode_with_ec(CELTDecoder * OPUS_RESTRICT st, const unsigned char *dat
       out_syn[c] = decode_mem[c]+DECODE_BUFFER_SIZE-N;
    } while (++c<CC);
 
-   if (CC==2&&C==1)
-      OPUS_COPY(freq+N, freq, N);
-   if (CC==1&&C==2)
-   {
-      for (i=0;i<N;i++)
-         freq[i] = HALF32(ADD32(freq[i],freq[N+i]));
-   }
-
-   /* Compute inverse MDCTs */
-   compute_inv_mdcts(mode, shortBlocks, freq, out_syn, CC, LM);
+   celt_synthesis(mode, X, out_syn, oldBandE, start, effEnd, C, CC, isTransient, LM, st->downsample, silence);
 
    c=0; do {
       st->postfilter_period=IMAX(st->postfilter_period, COMBFILTER_MINPERIOD);
@@ -995,7 +1035,7 @@ int celt_decode_with_ec(CELTDecoder * OPUS_RESTRICT st, const unsigned char *dat
    st->rng = dec->rng;
 
    /* We reuse freq[] as scratch space for the de-emphasis */
-   deemphasis(out_syn, pcm, N, CC, st->downsample, mode->preemph, st->preemph_memD, freq);
+   deemphasis(out_syn, pcm, N, CC, st->downsample, mode->preemph, st->preemph_memD);
    st->loss_count = 0;
    RESTORE_STACK;
    if (ec_tell(dec) > 8*len)
index 2841594..4cd8c4a 100644 (file)
@@ -1973,25 +1973,15 @@ int celt_encode_with_ec(CELTEncoder * OPUS_RESTRICT st, const opus_val16 * pcm,
                start, end, oldBandE, oldLogE, oldLogE2, pulses, st->rng);
       }
 
-      /* Synthesis */
-      denormalise_bands(mode, X, freq, oldBandE, start, effEnd, C, M,
-            st->upsample, silence);
-
       c=0; do {
          OPUS_MOVE(st->syn_mem[c], st->syn_mem[c]+N, 2*MAX_PERIOD-N+overlap/2);
       } while (++c<CC);
 
-      if (CC==2&&C==1)
-      {
-         for (i=0;i<N;i++)
-            freq[N+i] = freq[i];
-      }
-
       c=0; do {
          out_mem[c] = st->syn_mem[c]+2*MAX_PERIOD-N;
       } while (++c<CC);
 
-      compute_inv_mdcts(mode, shortBlocks, freq, out_mem, CC, LM);
+      celt_synthesis(mode, X, out_mem, oldBandE, start, effEnd, C, CC, isTransient, LM, st->upsample, silence);
 
       c=0; do {
          st->prefilter_period=IMAX(st->prefilter_period, COMBFILTER_MINPERIOD);
@@ -2006,7 +1996,7 @@ int celt_encode_with_ec(CELTEncoder * OPUS_RESTRICT st, const opus_val16 * pcm,
       } while (++c<CC);
 
       /* We reuse freq[] as scratch space for the de-emphasis */
-      deemphasis(out_mem, (opus_val16*)pcm, N, CC, st->upsample, mode->preemph, st->preemph_memD, freq);
+      deemphasis(out_mem, (opus_val16*)pcm, N, CC, st->upsample, mode->preemph, st->preemph_memD);
       st->prefilter_period_old = st->prefilter_period;
       st->prefilter_gain_old = st->prefilter_gain;
       st->prefilter_tapset_old = st->prefilter_tapset;