Note some more platforms where float-approx is tested, fix a bug in the prediction...
[opus.git] / libcelt / celt.c
index 222f2fe..5d32fdc 100644 (file)
@@ -81,8 +81,10 @@ struct CELTEncoder {
    int overlap;
    int channels;
    
-   int pitch_enabled;
-   int pitch_available;
+   int pitch_enabled;       /* Complexity level is allowed to use pitch */
+   int pitch_permitted;     /*  Use of the LTP is permitted by the user */
+   int pitch_available;     /*  Amount of pitch buffer available */
+   int force_intra;
    int delayedIntra;
    celt_word16_t tonal_average;
    int fold_decision;
@@ -139,7 +141,9 @@ CELTEncoder *celt_encoder_create(const CELTMode *mode)
 
    st->VBR_rate = 0;
    st->pitch_enabled = 1;
+   st->pitch_permitted = 1;
    st->pitch_available = 1;
+   st->force_intra  = 0;
    st->delayedIntra = 1;
    st->tonal_average = QCONST16(1.,8);
    st->fold_decision = 1;
@@ -308,7 +312,7 @@ static void compute_mdcts(const CELTMode *mode, int shortBlocks, celt_sig_t * re
          mdct_forward(lookup, x, tmp, mode->window, overlap);
          /* Interleaving the sub-frames */
          for (j=0;j<N;j++)
-            out[C*j+c] = tmp[j];
+            out[j+c*N] = tmp[j];
       }
       RESTORE_STACK;
    } else {
@@ -332,7 +336,7 @@ static void compute_mdcts(const CELTMode *mode, int shortBlocks, celt_sig_t * re
             mdct_forward(lookup, x, tmp, mode->window, overlap);
             /* Interleaving the sub-frames */
             for (j=0;j<N;j++)
-               out[C*(j*B+b)+c] = tmp[j];
+               out[(j*B+b)+c*N*B] = tmp[j];
          }
       }
       RESTORE_STACK;
@@ -363,7 +367,7 @@ static void compute_inv_mdcts(const CELTMode *mode, int shortBlocks, celt_sig_t
          ALLOC(tmp, N, celt_word32_t);
          /* De-interleaving the sub-frames */
          for (j=0;j<N;j++)
-            tmp[j] = X[C*j+c];
+            tmp[j] = X[j+c*N];
          /* Prevents problems from the imdct doing the overlap-add */
          CELT_MEMSET(x+N4, 0, N);
          mdct_backward(lookup, tmp, x, mode->window, overlap);
@@ -393,7 +397,7 @@ static void compute_inv_mdcts(const CELTMode *mode, int shortBlocks, celt_sig_t
          {
             /* De-interleaving the sub-frames */
             for (j=0;j<N2;j++)
-               tmp[j] = X[C*(j*B+b)+c];
+               tmp[j] = X[(j*B+b)+c*N2*B];
             mdct_backward(lookup, tmp, x+N4+N2*b, mode->window, overlap);
          }
          if (transient_shift > 0)
@@ -505,11 +509,13 @@ int celt_encode_float(CELTEncoder * restrict st, const celt_sig_t * pcm, celt_si
    VARDECL(celt_norm_t, X);
    VARDECL(celt_norm_t, P);
    VARDECL(celt_ener_t, bandE);
+   VARDECL(celt_word16_t, bandLogE);
    VARDECL(celt_pgain_t, gains);
    VARDECL(int, fine_quant);
    VARDECL(celt_word16_t, error);
    VARDECL(int, pulses);
    VARDECL(int, offsets);
+   VARDECL(int, fine_priority);
 #ifdef EXP_PSY
    VARDECL(celt_word32_t, mask);
    VARDECL(celt_word32_t, tonality);
@@ -560,52 +566,47 @@ int celt_encode_float(CELTEncoder * restrict st, const celt_sig_t * pcm, celt_si
       }
    }
    CELT_COPY(st->in_mem, in+C*(2*N-2*N4-st->overlap), C*st->overlap);
-   
+
    /* Transient handling */
-   if (st->mode->nbShortMdcts > 1)
+   transient_time = -1;
+   transient_shift = 0;
+   shortBlocks = 0;
+
+   if (st->mode->nbShortMdcts > 1 && transient_analysis(in, N+st->overlap, C, &transient_time, &transient_shift))
    {
-      if (transient_analysis(in, N+st->overlap, C, &transient_time, &transient_shift))
-      {
 #ifndef FIXED_POINT
-         float gain_1;
+      float gain_1;
 #endif
-         /* Apply the inverse shaping window */
-         if (transient_shift)
-         {
+      /* Apply the inverse shaping window */
+      if (transient_shift)
+      {
 #ifdef FIXED_POINT
-            for (c=0;c<C;c++)
-               for (i=0;i<16;i++)
-                  in[C*(transient_time+i-16)+c] = MULT16_32_Q15(EXTRACT16(SHR32(celt_rcp(Q15ONE+MULT16_16(transientWindow[i],((1<<transient_shift)-1))),1)), in[C*(transient_time+i-16)+c]);
-            for (c=0;c<C;c++)
-               for (i=transient_time;i<N+st->overlap;i++)
-                  in[C*i+c] = SHR32(in[C*i+c], transient_shift);
+         for (c=0;c<C;c++)
+            for (i=0;i<16;i++)
+               in[C*(transient_time+i-16)+c] = MULT16_32_Q15(EXTRACT16(SHR32(celt_rcp(Q15ONE+MULT16_16(transientWindow[i],((1<<transient_shift)-1))),1)), in[C*(transient_time+i-16)+c]);
+         for (c=0;c<C;c++)
+            for (i=transient_time;i<N+st->overlap;i++)
+               in[C*i+c] = SHR32(in[C*i+c], transient_shift);
 #else
-            for (c=0;c<C;c++)
-               for (i=0;i<16;i++)
-                  in[C*(transient_time+i-16)+c] /= 1+transientWindow[i]*((1<<transient_shift)-1);
-            gain_1 = 1./(1<<transient_shift);
-            for (c=0;c<C;c++)
-               for (i=transient_time;i<N+st->overlap;i++)
-                  in[C*i+c] *= gain_1;
+         for (c=0;c<C;c++)
+            for (i=0;i<16;i++)
+               in[C*(transient_time+i-16)+c] /= 1+transientWindow[i]*((1<<transient_shift)-1);
+         gain_1 = 1./(1<<transient_shift);
+         for (c=0;c<C;c++)
+            for (i=transient_time;i<N+st->overlap;i++)
+               in[C*i+c] *= gain_1;
 #endif
-         }
-         shortBlocks = 1;
-         has_fold = 1;
-      } else {
-         transient_time = -1;
-         transient_shift = 0;
-         shortBlocks = 0;
       }
-   } else {
-      transient_time = -1;
-      transient_shift = 0;
-      shortBlocks = 0;
+      shortBlocks = 1;
+      has_fold = 1;
    }
 
    ALLOC(freq, C*N, celt_sig_t); /**< Interleaved signal MDCTs */
    ALLOC(bandE,st->mode->nbEBands*C, celt_ener_t);
+   ALLOC(bandLogE,st->mode->nbEBands*C, celt_word16_t);
    /* Compute MDCTs */
    compute_mdcts(st->mode, shortBlocks, in, freq);
+
    if (shortBlocks && !transient_shift) 
    {
       celt_word32_t sum[4]={1,1,1,1};
@@ -615,7 +616,7 @@ int celt_encode_float(CELTEncoder * restrict st, const celt_sig_t * pcm, celt_si
          m=0;
          do {
             celt_word32_t tmp=0;
-            for (i=m*C+c;i<N;i+=C*st->mode->nbShortMdcts)
+            for (i=m+c*N;i<(c+1)*N;i+=st->mode->nbShortMdcts)
                tmp += ABS32(freq[i]);
             sum[m++] += tmp;
          } while (m<st->mode->nbShortMdcts);
@@ -638,7 +639,7 @@ int celt_encode_float(CELTEncoder * restrict st, const celt_sig_t * pcm, celt_si
       {
          for (c=0;c<C;c++)
             for (m=mdct_weight_pos+1;m<st->mode->nbShortMdcts;m++)
-               for (i=m*C+c;i<N;i+=C*st->mode->nbShortMdcts)
+               for (i=m+c*N;i<(c+1)*N;i+=st->mode->nbShortMdcts)
                   freq[i] = SHR32(freq[i],mdct_weight_shift);
       }
 #else
@@ -658,31 +659,27 @@ int celt_encode_float(CELTEncoder * restrict st, const celt_sig_t * pcm, celt_si
       {
          for (c=0;c<C;c++)
             for (m=mdct_weight_pos+1;m<st->mode->nbShortMdcts;m++)
-               for (i=m*C+c;i<N;i+=C*st->mode->nbShortMdcts)
+               for (i=m+c*N;i<(c+1)*N;i+=st->mode->nbShortMdcts)
                   freq[i] = (1./(1<<mdct_weight_shift))*freq[i];
       }
 #endif
-      /*printf ("%f\n", short_ratio);*/
-      /*if (short_ratio < 1)
-         short_ratio = 1;
-      short_ratio = 1<<(int)floor(.5+log2(short_ratio));
-      if (short_ratio>4)
-         short_ratio = 4;*/
-   }/* else if (transient_shift)
-      printf ("8\n");
-      else printf ("1\n");*/
+   }
 
    compute_band_energies(st->mode, freq, bandE);
+   for (i=0;i<st->mode->nbEBands*C;i++)
+      bandLogE[i] = amp2Log(bandE[i]);
 
-   intra_ener = st->delayedIntra;
-   if (intra_decision(bandE, st->oldBandE, st->mode->nbEBands) || shortBlocks)
+   /* Don't use intra energy when we're operating at low bit-rate */
+   intra_ener = st->force_intra || (st->delayedIntra && nbCompressedBytes > st->mode->nbEBands);
+   if (shortBlocks || intra_decision(bandLogE, st->oldBandE, st->mode->nbEBands))
       st->delayedIntra = 1;
    else
       st->delayedIntra = 0;
+
    /* Pitch analysis: we do it early to save on the peak stack space */
    /* Don't use pitch if there isn't enough data available yet, 
       or if we're using shortBlocks */
-   has_pitch = st->pitch_enabled && (st->pitch_available >= MAX_PERIOD) && (!shortBlocks) && !intra_ener;
+   has_pitch = st->pitch_enabled && st->pitch_permitted && (st->pitch_available >= MAX_PERIOD) && (!shortBlocks) && !intra_ener;
 #ifdef EXP_PSY
    ALLOC(tonality, MAX_PERIOD/4, celt_word16_t);
    {
@@ -754,9 +751,14 @@ int celt_encode_float(CELTEncoder * restrict st, const celt_sig_t * pcm, celt_si
       compute_band_energies(st->mode, freq, bandEp);
       normalise_bands(st->mode, freq, P, bandEp);
       pitch_power = bandEp[0]+bandEp[1]+bandEp[2];
-      /* Check if we can safely use the pitch (i.e. effective gain 
-         isn't too high) */
       curr_power = bandE[0]+bandE[1]+bandE[2];
+      if (C>1)
+      {
+         pitch_power += bandEp[0+st->mode->nbEBands]+bandEp[1+st->mode->nbEBands]+bandEp[2+st->mode->nbEBands];
+         curr_power += bandE[0+st->mode->nbEBands]+bandE[1+st->mode->nbEBands]+bandE[2+st->mode->nbEBands];
+      }
+      /* Check if we can safely use the pitch (i.e. effective gain 
+      isn't too high) */
       if ((MULT16_32_Q15(QCONST16(.1f, 15),curr_power) + QCONST32(10.f,ENER_SHIFT) < pitch_power))
       {
          /* Pitch prediction */
@@ -808,7 +810,7 @@ int celt_encode_float(CELTEncoder * restrict st, const celt_sig_t * pcm, celt_si
 
    /* Bit allocation */
    ALLOC(error, C*st->mode->nbEBands, celt_word16_t);
-   coarse_needed = quant_coarse_energy(st->mode, bandE, st->oldBandE, nbCompressedBytes*8/3, intra_ener, st->mode->prob, error, &enc);
+   coarse_needed = quant_coarse_energy(st->mode, bandLogE, st->oldBandE, nbCompressedBytes*8/3, intra_ener, st->mode->prob, error, &enc);
    coarse_needed = ((coarse_needed*3-1)>>3)+1;
 
    /* Variable bitrate */
@@ -834,6 +836,7 @@ int celt_encode_float(CELTEncoder * restrict st, const celt_sig_t * pcm, celt_si
    }
 
    ALLOC(offsets, st->mode->nbEBands, int);
+   ALLOC(fine_priority, st->mode->nbEBands, int);
 
    for (i=0;i<st->mode->nbEBands;i++)
       offsets[i] = 0;
@@ -841,7 +844,7 @@ int celt_encode_float(CELTEncoder * restrict st, const celt_sig_t * pcm, celt_si
    if (has_pitch)
       bits -= st->mode->nbPBands;
 #ifndef STDIN_TUNING
-   compute_allocation(st->mode, offsets, bits, pulses, fine_quant);
+   compute_allocation(st->mode, offsets, bits, pulses, fine_quant, fine_priority);
 #endif
 
    quant_fine_energy(st->mode, bandE, st->oldBandE, error, fine_quant, &enc);
@@ -853,6 +856,9 @@ int celt_encode_float(CELTEncoder * restrict st, const celt_sig_t * pcm, celt_si
    else
       quant_bands_stereo(st->mode, X, P, NULL, has_pitch, gains, bandE, pulses, shortBlocks, has_fold, nbCompressedBytes*8, &enc);
 #endif
+
+   quant_energy_finalise(st->mode, bandE, st->oldBandE, error, fine_quant, fine_priority, nbCompressedBytes*8-ec_enc_tell(&enc, 0), &enc);
+
    /* Re-synthesis of the coded audio if required */
    if (st->pitch_available>0 || optional_synthesis!=NULL)
    {
@@ -870,7 +876,7 @@ int celt_encode_float(CELTEncoder * restrict st, const celt_sig_t * pcm, celt_si
          int m;
          for (c=0;c<C;c++)
             for (m=mdct_weight_pos+1;m<st->mode->nbShortMdcts;m++)
-               for (i=m*C+c;i<N;i+=C*st->mode->nbShortMdcts)
+               for (i=m+c*N;i<(c+1)*N;i+=st->mode->nbShortMdcts)
 #ifdef FIXED_POINT
                   freq[i] = SHL32(freq[i], mdct_weight_shift);
 #else
@@ -895,28 +901,8 @@ int celt_encode_float(CELTEncoder * restrict st, const celt_sig_t * pcm, celt_si
       }
    }
 
-   /* Finishing the stream with a 0101... pattern so that the 
-      decoder can check is everything's right */
-   {
-      int val = 0;
-      while (ec_enc_tell(&enc, 0) < nbCompressedBytes*8)
-      {
-         ec_enc_uint(&enc, val, 2);
-         val = 1-val;
-      }
-   }
    ec_enc_done(&enc);
-   {
-      /*unsigned char *data;*/
-      int nbBytes = ec_byte_bytes(&buf);
-      if (nbBytes > nbCompressedBytes)
-      {
-         celt_warning_int ("got too many bytes:", nbBytes);
-         RESTORE_STACK;
-         return CELT_INTERNAL_ERROR;
-      }
-   }
-
+   
    RESTORE_STACK;
    return nbCompressedBytes;
 }
@@ -1021,15 +1007,22 @@ int celt_encoder_ctl(CELTEncoder * restrict st, int request, ...)
          }   
       }
       break;
-      case CELT_SET_LTP_REQUEST:
+      case CELT_SET_PREDICTION_REQUEST:
       {
          int value = va_arg(ap, celt_int32_t);
-         if (value<0 || value>1 || (value==1 && st->pitch_available==0))
+         if (value<0 || value>2)
             goto bad_arg;
          if (value==0)
-            st->pitch_enabled = 0;
-         else
-            st->pitch_enabled = 1;
+         {
+            st->force_intra   = 1;
+            st->pitch_permitted = 0;
+         } else if (value==1) {
+            st->force_intra   = 0;
+            st->pitch_permitted = 0;
+         } else {
+            st->force_intra   = 0;
+            st->pitch_permitted = 1;
+         }   
       }
       break;
       case CELT_SET_VBR_RATE_REQUEST:
@@ -1237,7 +1230,7 @@ static void celt_decode_lost(CELTDecoder * restrict st, celt_word16_t * restrict
    while (offset+len >= MAX_PERIOD)
       offset -= pitch_index;
    compute_mdcts(st->mode, 0, st->out_mem+offset*C, freq);
-   for (i=0;i<N;i++)
+   for (i=0;i<C*N;i++)
       freq[i] = ADD32(EPSILON, MULT16_32_Q15(QCONST16(.9f,15),freq[i]));
 #endif
    
@@ -1283,6 +1276,7 @@ int celt_decode_float(CELTDecoder * restrict st, const unsigned char *data, int
    VARDECL(int, fine_quant);
    VARDECL(int, pulses);
    VARDECL(int, offsets);
+   VARDECL(int, fine_priority);
 
    int shortBlocks;
    int intra_ener;
@@ -1357,6 +1351,7 @@ int celt_decode_float(CELTDecoder * restrict st, const unsigned char *data, int
    
    ALLOC(pulses, st->mode->nbEBands, int);
    ALLOC(offsets, st->mode->nbEBands, int);
+   ALLOC(fine_priority, st->mode->nbEBands, int);
 
    for (i=0;i<st->mode->nbEBands;i++)
       offsets[i] = 0;
@@ -1364,7 +1359,7 @@ int celt_decode_float(CELTDecoder * restrict st, const unsigned char *data, int
    bits = len*8 - ec_dec_tell(&dec, 0) - 1;
    if (has_pitch)
       bits -= st->mode->nbPBands;
-   compute_allocation(st->mode, offsets, bits, pulses, fine_quant);
+   compute_allocation(st->mode, offsets, bits, pulses, fine_quant, fine_priority);
    /*bits = ec_dec_tell(&dec, 0);
    compute_fine_allocation(st->mode, fine_quant, (20*C+len*8/5-(ec_dec_tell(&dec, 0)-bits))/C);*/
    
@@ -1393,6 +1388,8 @@ int celt_decode_float(CELTDecoder * restrict st, const unsigned char *data, int
    else
       unquant_bands_stereo(st->mode, X, P, has_pitch, gains, bandE, pulses, shortBlocks, has_fold, len*8, &dec);
 #endif
+   unquant_energy_finalise(st->mode, bandE, st->oldBandE, fine_quant, fine_priority, len*8-ec_dec_tell(&dec, 0), &dec);
+   
    /* Synthesis */
    denormalise_bands(st->mode, X, freq, bandE);
 
@@ -1403,7 +1400,7 @@ int celt_decode_float(CELTDecoder * restrict st, const unsigned char *data, int
       int m;
       for (c=0;c<C;c++)
          for (m=mdct_weight_pos+1;m<st->mode->nbShortMdcts;m++)
-            for (i=m*C+c;i<N;i+=C*st->mode->nbShortMdcts)
+            for (i=m+c*N;i<(c+1)*N;i+=st->mode->nbShortMdcts)
 #ifdef FIXED_POINT
                freq[i] = SHL32(freq[i], mdct_weight_shift);
 #else
@@ -1425,20 +1422,6 @@ int celt_decode_float(CELTDecoder * restrict st, const unsigned char *data, int
       }
    }
 
-   {
-      unsigned int val = 0;
-      while (ec_dec_tell(&dec, 0) < len*8)
-      {
-         if (ec_dec_uint(&dec, 2) != val)
-         {
-            celt_warning("decode error");
-            RESTORE_STACK;
-            return CELT_CORRUPTED_DATA;
-         }
-         val = 1-val;
-      }
-   }
-
    RESTORE_STACK;
    return 0;
    /*printf ("\n");*/