Add a SET_LTP call to allow on the fly activation/deactivation of the long term
[opus.git] / libcelt / celt.c
index c8b288b..9ebd8df 100644 (file)
@@ -77,9 +77,7 @@ struct CELTEncoder {
    int channels;
    
    int pitch_enabled;
-   
-   ec_byte_buffer buf;
-   ec_enc         enc;
+   int pitch_available;
 
    celt_word16_t * restrict preemph_memE; /* Input is 16-bit, so why bother with 32 */
    celt_sig_t    * restrict preemph_memD;
@@ -112,9 +110,7 @@ CELTEncoder *celt_encoder_create(const CELTMode *mode)
    st->overlap = mode->overlap;
 
    st->pitch_enabled = 1;
-   
-   ec_byte_writeinit(&st->buf);
-   ec_enc_init(&st->enc,&st->buf);
+   st->pitch_available = 1;
 
    st->in_mem = celt_alloc(st->overlap*C*sizeof(celt_sig_t));
    st->out_mem = celt_alloc((MAX_PERIOD+st->overlap)*C*sizeof(celt_sig_t));
@@ -142,8 +138,6 @@ void celt_encoder_destroy(CELTEncoder *st)
    if (check_mode(st->mode) != CELT_OK)
       return;
 
-   ec_byte_writeclear(&st->buf);
-
    celt_free(st->in_mem);
    celt_free(st->out_mem);
    
@@ -380,9 +374,12 @@ int celt_encode_float(CELTEncoder * restrict st, const celt_sig_t * pcm, celt_si
 #endif
    int i, c, N, N4;
    int has_pitch;
+   int id;
    int pitch_index;
    int bits;
-   celt_word32_t curr_power, pitch_power=0;
+   int has_fold=1;
+   ec_byte_buffer buf;
+   ec_enc         enc;
    VARDECL(celt_sig_t, in);
    VARDECL(celt_sig_t, freq);
    VARDECL(celt_norm_t, X);
@@ -396,6 +393,9 @@ int celt_encode_float(CELTEncoder * restrict st, const celt_sig_t * pcm, celt_si
    VARDECL(int, offsets);
 #ifdef EXP_PSY
    VARDECL(celt_word32_t, mask);
+   VARDECL(celt_word32_t, tonality);
+   VARDECL(celt_word32_t, bandM);
+   VARDECL(celt_ener_t, bandN);
 #endif
    int shortBlocks=0;
    int transient_time;
@@ -406,6 +406,11 @@ int celt_encode_float(CELTEncoder * restrict st, const celt_sig_t * pcm, celt_si
    if (check_mode(st->mode) != CELT_OK)
       return CELT_INVALID_MODE;
 
+   /* The memset is important for now in case the encoder doesn't fill up all the bytes */
+   CELT_MEMSET(compressed, 0, nbCompressedBytes);
+   ec_byte_writeinit_buffer(&buf, compressed, nbCompressedBytes);
+   ec_enc_init(&enc,&buf);
+
    N = st->block_size;
    N4 = (N-st->overlap)>>1;
    ALLOC(in, 2*C*N-2*C*N4, celt_sig_t);
@@ -427,6 +432,7 @@ int celt_encode_float(CELTEncoder * restrict st, const celt_sig_t * pcm, celt_si
    }
    CELT_COPY(st->in_mem, in+C*(2*N-2*N4-st->overlap), C*st->overlap);
    
+   /* Transient handling */
    if (st->mode->nbShortMdcts > 1)
    {
       if (transient_analysis(in, N+st->overlap, C, &transient_time, &transient_shift))
@@ -434,11 +440,12 @@ int celt_encode_float(CELTEncoder * restrict st, const celt_sig_t * pcm, celt_si
 #ifndef FIXED_POINT
          float gain_1;
 #endif
-         ec_enc_bits(&st->enc, 0, 1); //Pitch off
-         ec_enc_bits(&st->enc, 1, 1); //Transient on
-         ec_enc_bits(&st->enc, transient_shift, 2);
+         ec_enc_bits(&enc, 0, 1); //Pitch off
+         ec_enc_bits(&enc, 1, 1); //Transient on
+         ec_enc_bits(&enc, transient_shift, 2);
          if (transient_shift)
-            ec_enc_uint(&st->enc, transient_time, N+st->overlap);
+            ec_enc_uint(&enc, transient_time, N+st->overlap);
+         /* Apply the inverse shaping window */
          if (transient_shift)
          {
 #ifdef FIXED_POINT
@@ -469,57 +476,77 @@ int celt_encode_float(CELTEncoder * restrict st, const celt_sig_t * pcm, celt_si
       transient_shift = 0;
       shortBlocks = 0;
    }
-   /* Pitch analysis: we do it early to save on the peak stack space */
-   if (st->pitch_enabled && !shortBlocks)
-      find_spectral_pitch(st->mode, st->mode->fft, &st->mode->psy, in, st->out_mem, st->mode->window, 2*N-2*N4, MAX_PERIOD-(2*N-2*N4), &pitch_index);
 
+   /* Pitch analysis: we do it early to save on the peak stack space */
+   /* Don't use pitch if there isn't enough data available yet, or if we're using shortBlocks */
+   has_pitch = st->pitch_enabled && (st->pitch_available >= MAX_PERIOD) && (!shortBlocks);
+#ifdef EXP_PSY
+   ALLOC(tonality, MAX_PERIOD/4, celt_word16_t);
+   {
+      VARDECL(celt_word16_t, X);
+      ALLOC(X, MAX_PERIOD/2, celt_word16_t);
+      find_spectral_pitch(st->mode, st->mode->fft, &st->mode->psy, in, st->out_mem, st->mode->window, X, 2*N-2*N4, MAX_PERIOD-(2*N-2*N4), &pitch_index);
+      compute_tonality(st->mode, X, st->psy_mem, MAX_PERIOD, tonality, MAX_PERIOD/4);
+   }
+#else
+   if (has_pitch)
+   {
+      find_spectral_pitch(st->mode, st->mode->fft, &st->mode->psy, in, st->out_mem, st->mode->window, NULL, 2*N-2*N4, MAX_PERIOD-(2*N-2*N4), &pitch_index);
+   }
+#endif
    ALLOC(freq, C*N, celt_sig_t); /**< Interleaved signal MDCTs */
    
-   /*for (i=0;i<(B+1)*C*N;i++) printf ("%f(%d) ", in[i], i); printf ("\n");*/
    /* Compute MDCTs */
    compute_mdcts(st->mode, shortBlocks, in, freq);
 
 #ifdef EXP_PSY
-   CELT_MOVE(st->psy_mem, st->out_mem+N, MAX_PERIOD+st->overlap-N);
-   for (i=0;i<N;i++)
-      st->psy_mem[MAX_PERIOD+st->overlap-N+i] = in[C*(st->overlap+i)];
-   for (c=1;c<C;c++)
-      for (i=0;i<N;i++)
-         st->psy_mem[MAX_PERIOD+st->overlap-N+i] += in[C*(st->overlap+i)+c];
-
    ALLOC(mask, N, celt_sig_t);
-   compute_mdct_masking(&st->psy, freq, st->psy_mem, mask, C*N);
-
-   /* Invert and stretch the mask to length of X 
-      For some reason, I get better results by using the sqrt instead,
-      although there's no valid reason to. Must investigate further */
-   for (i=0;i<C*N;i++)
-      mask[i] = 1/(.1+mask[i]);
+   compute_mdct_masking(&st->psy, freq, tonality, st->psy_mem, mask, C*N);
+   /*for (i=0;i<256;i++)
+      printf ("%f %f %f ", freq[i], tonality[i], mask[i]);
+   printf ("\n");*/
 #endif
-   
+
    /* Deferred allocation after find_spectral_pitch() to reduce the peak memory usage */
    ALLOC(X, C*N, celt_norm_t);         /**< Interleaved normalised MDCTs */
    ALLOC(P, C*N, celt_norm_t);         /**< Interleaved normalised pitch MDCTs*/
    ALLOC(bandE,st->mode->nbEBands*C, celt_ener_t);
    ALLOC(gains,st->mode->nbPBands, celt_pgain_t);
 
-   /*printf ("%f %f\n", curr_power, pitch_power);*/
-   /*int j;
-   for (j=0;j<B*N;j++)
-      printf ("%f ", X[j]);
-   for (j=0;j<B*N;j++)
-      printf ("%f ", P[j]);
-   printf ("\n");*/
 
    /* Band normalisation */
    compute_band_energies(st->mode, freq, bandE);
    normalise_bands(st->mode, freq, X, bandE);
-   /*for (i=0;i<st->mode->nbEBands;i++)printf("%f ", bandE[i]);printf("\n");*/
-   /*for (i=0;i<N*B*C;i++)printf("%f ", X[i]);printf("\n");*/
+
+#ifdef EXP_PSY
+   ALLOC(bandN,C*st->mode->nbEBands, celt_ener_t);
+   ALLOC(bandM,st->mode->nbEBands, celt_ener_t);
+   compute_noise_energies(st->mode, freq, tonality, bandN);
+
+   /*for (i=0;i<st->mode->nbEBands;i++)
+      printf ("%f ", (.1+bandN[i])/(.1+bandE[i]));
+   printf ("\n");*/
+   has_fold = 0;
+   for (i=st->mode->nbPBands;i<st->mode->nbEBands;i++)
+      if (bandN[i] < .4*bandE[i])
+         has_fold++;
+   /*printf ("%d\n", has_fold);*/
+   if (has_fold>=2)
+      has_fold = 0;
+   else
+      has_fold = 1;
+   for (i=0;i<N;i++)
+      mask[i] = sqrt(mask[i]);
+   compute_band_energies(st->mode, mask, bandM);
+   /*for (i=0;i<st->mode->nbEBands;i++)
+      printf ("%f %f ", bandE[i], bandM[i]);
+   printf ("\n");*/
+#endif
 
    /* Compute MDCTs of the pitch part */
-   if (st->pitch_enabled && !shortBlocks)
+   if (has_pitch)
    {
+      celt_word32_t curr_power, pitch_power=0;
       /* Normalise the pitch vector as well (discard the energies) */
       VARDECL(celt_ener_t, bandEp);
       
@@ -528,29 +555,35 @@ int celt_encode_float(CELTEncoder * restrict st, const celt_sig_t * pcm, celt_si
       compute_band_energies(st->mode, freq, bandEp);
       normalise_bands(st->mode, freq, P, bandEp);
       pitch_power = bandEp[0]+bandEp[1]+bandEp[2];
+      /* Check if we can safely use the pitch (i.e. effective gain isn't too high) */
+      curr_power = bandE[0]+bandE[1]+bandE[2];
+      id=-1;
+      if ((MULT16_32_Q15(QCONST16(.1f, 15),curr_power) + QCONST32(10.f,ENER_SHIFT) < pitch_power))
+      {
+         /* Pitch prediction */
+         compute_pitch_gain(st->mode, X, P, gains);
+         id = quant_pitch(gains, st->mode->nbPBands);
+      } 
+      if (id == -1)
+         has_pitch = 0;
    }
-   curr_power = bandE[0]+bandE[1]+bandE[2];
-   /* Check if we can safely use the pitch (i.e. effective gain isn't too high) */
-   if (st->pitch_enabled && !shortBlocks && (MULT16_32_Q15(QCONST16(.1f, 15),curr_power) + QCONST32(10.f,ENER_SHIFT) < pitch_power))
-   {
-      /* Simulates intensity stereo */
-      /*for (i=30;i<N*B;i++)
-         X[i*C+1] = P[i*C+1] = 0;*/
-
-      /* Pitch prediction */
-      compute_pitch_gain(st->mode, X, P, gains);
-      has_pitch = quant_pitch(gains, st->mode->nbPBands, &st->enc);
-      if (has_pitch)
-         ec_enc_uint(&st->enc, pitch_index, MAX_PERIOD-(2*N-2*N4));
-      else if (st->mode->nbShortMdcts > 1)
-         ec_enc_bits(&st->enc, 0, 1); //Transient off
+   
+   if (has_pitch) 
+   {  
+      unquant_pitch(id, gains, st->mode->nbPBands);
+      ec_enc_bits(&enc, has_pitch, 1); /* Pitch flag */
+      ec_enc_bits(&enc, has_fold, 1); /* Folding flag */
+      ec_enc_bits(&enc, id, 7);
+      ec_enc_uint(&enc, pitch_index, MAX_PERIOD-(2*N-2*N4));
+      pitch_quant_bands(st->mode, P, gains);
    } else {
       if (!shortBlocks)
       {
-         ec_enc_bits(&st->enc, 0, 1); //Pitch off
+         ec_enc_bits(&enc, 0, 1); /* Pitch off */
          if (st->mode->nbShortMdcts > 1)
-           ec_enc_bits(&st->enc, 0, 1); //Transient off
+           ec_enc_bits(&enc, 0, 1); /* Transient off */
       }
+      has_fold = 1;
       /* No pitch, so we just pretend we found a gain of zero */
       for (i=0;i<st->mode->nbPBands;i++)
          gains[i] = 0;
@@ -574,8 +607,10 @@ int celt_encode_float(CELTEncoder * restrict st, const celt_sig_t * pcm, celt_si
    ALLOC(fine_quant, st->mode->nbEBands, int);
    ALLOC(pulses, st->mode->nbEBands, int);
 #endif
+
+   /* Bit allocation */
    ALLOC(error, C*st->mode->nbEBands, celt_word16_t);
-   quant_coarse_energy(st->mode, bandE, st->oldBandE, nbCompressedBytes*8/3, st->mode->prob, error, &st->enc);
+   quant_coarse_energy(st->mode, bandE, st->oldBandE, nbCompressedBytes*8/3, st->mode->prob, error, &enc);
    
    ALLOC(offsets, st->mode->nbEBands, int);
    ALLOC(stereo_mode, st->mode->nbEBands, int);
@@ -583,28 +618,22 @@ int celt_encode_float(CELTEncoder * restrict st, const celt_sig_t * pcm, celt_si
 
    for (i=0;i<st->mode->nbEBands;i++)
       offsets[i] = 0;
-   bits = nbCompressedBytes*8 - ec_enc_tell(&st->enc, 0) - 1;
+   bits = nbCompressedBytes*8 - ec_enc_tell(&enc, 0) - 1;
 #ifndef STDIN_TUNING
    compute_allocation(st->mode, offsets, stereo_mode, bits, pulses, fine_quant);
 #endif
-   /*for (i=0;i<st->mode->nbEBands;i++)
-      printf("%d ", fine_quant[i]);
-   for (i=0;i<st->mode->nbEBands;i++)
-      printf("%d ", pulses[i]);
-   printf ("\n");*/
-   /*bits = ec_enc_tell(&st->enc, 0);
-   compute_fine_allocation(st->mode, fine_quant, (20*C+nbCompressedBytes*8/5-(ec_enc_tell(&st->enc, 0)-bits))/C);*/
-   quant_fine_energy(st->mode, bandE, st->oldBandE, error, fine_quant, &st->enc);
-
-   pitch_quant_bands(st->mode, P, gains);
 
-   /*for (i=0;i<B*N;i++) printf("%f ",P[i]);printf("\n");*/
+   quant_fine_energy(st->mode, bandE, st->oldBandE, error, fine_quant, &enc);
 
    /* Residual quantisation */
-   quant_bands(st->mode, X, P, NULL, bandE, stereo_mode, pulses, shortBlocks, nbCompressedBytes*8, &st->enc);
-   
-   if (st->pitch_enabled || optional_synthesis!=NULL)
+   quant_bands(st->mode, X, P, NULL, bandE, stereo_mode, pulses, shortBlocks, has_fold, nbCompressedBytes*8, &enc);
+
+   /* Re-synthesis of the coded audio if required */
+   if (st->pitch_available>0 || optional_synthesis!=NULL)
    {
+      if (st->pitch_available>0 && st->pitch_available<MAX_PERIOD)
+        st->pitch_available+=st->frame_size;
+
       if (C==2)
          renormalise_bands(st->mode, X);
       /* Synthesis */
@@ -636,32 +665,23 @@ int celt_encode_float(CELTEncoder * restrict st, const celt_sig_t * pcm, celt_si
    /* Finishing the stream with a 0101... pattern so that the decoder can check is everything's right */
    {
       int val = 0;
-      while (ec_enc_tell(&st->enc, 0) < nbCompressedBytes*8)
+      while (ec_enc_tell(&enc, 0) < nbCompressedBytes*8)
       {
-         ec_enc_uint(&st->enc, val, 2);
+         ec_enc_uint(&enc, val, 2);
          val = 1-val;
       }
    }
-   ec_enc_done(&st->enc);
+   ec_enc_done(&enc);
    {
-      unsigned char *data;
-      int nbBytes = ec_byte_bytes(&st->buf);
+      /*unsigned char *data;*/
+      int nbBytes = ec_byte_bytes(&buf);
       if (nbBytes > nbCompressedBytes)
       {
          celt_warning_int ("got too many bytes:", nbBytes);
          RESTORE_STACK;
          return CELT_INTERNAL_ERROR;
       }
-      /*printf ("%d\n", *nbBytes);*/
-      data = ec_byte_get_buffer(&st->buf);
-      for (i=0;i<nbBytes;i++)
-         compressed[i] = data[i];
-      for (;i<nbCompressedBytes;i++)
-         compressed[i] = 0;
    }
-   /* Reset the packing for the next encoding */
-   ec_byte_reset(&st->buf);
-   ec_enc_init(&st->enc,&st->buf);
 
    RESTORE_STACK;
    return nbCompressedBytes;
@@ -683,8 +703,7 @@ int celt_encode_float(CELTEncoder * restrict st, const float * pcm, float * opti
 
    if (optional_synthesis != NULL) {
      ret=celt_encode(st,in,in,compressed,nbCompressedBytes);
-   /*Converts backwards for inplace operation*/
-      for (j=0;j=C*N;j++)
+      for (j=0;j<C*N;j++)
          optional_synthesis[j]=in[j]*(1/32768.);
    } else {
      ret=celt_encode(st,in,NULL,compressed,nbCompressedBytes);
@@ -730,7 +749,22 @@ int celt_encoder_ctl(CELTEncoder * restrict st, int request, ...)
          int value = va_arg(ap, int);
          if (value<0 || value>10)
             goto bad_arg;
-         if (value<=2)
+         if (value<=2) {
+            st->pitch_enabled = 0; 
+            st->pitch_available = 0;
+         } else {
+              st->pitch_enabled = 1;
+              if (st->pitch_available<1)
+                st->pitch_available = 1;
+         }   
+      }
+      break;
+      case CELT_SET_LTP_REQUEST:
+      {
+         int value = va_arg(ap, int);
+         if (value<0 || value>1 || (value==1 && st->pitch_available==0))
+            goto bad_arg;
+         if (value==0)
             st->pitch_enabled = 0;
          else
             st->pitch_enabled = 1;
@@ -846,14 +880,14 @@ static void celt_decode_lost(CELTDecoder * restrict st, celt_word16_t * restrict
    compute_mdcts(st->mode, st->mode->window, st->out_mem+pitch_index*C, freq);
 
 #else
-   find_spectral_pitch(st->mode, st->mode->fft, &st->mode->psy, st->out_mem+MAX_PERIOD-len, st->out_mem, st->mode->window, len, MAX_PERIOD-len-100, &pitch_index);
+   find_spectral_pitch(st->mode, st->mode->fft, &st->mode->psy, st->out_mem+MAX_PERIOD-len, st->out_mem, st->mode->window, NULL, len, MAX_PERIOD-len-100, &pitch_index);
    pitch_index = MAX_PERIOD-len-pitch_index;
    offset = MAX_PERIOD-pitch_index;
    while (offset+len >= MAX_PERIOD)
       offset -= pitch_index;
    compute_mdcts(st->mode, 0, st->out_mem+offset*C, freq);
    for (i=0;i<N;i++)
-      freq[i] = MULT16_32_Q15(QCONST16(.9f,15),freq[i]);
+      freq[i] = ADD32(EPSILON, MULT16_32_Q15(QCONST16(.9f,15),freq[i]));
 #endif
    
    
@@ -955,16 +989,16 @@ int celt_decode_float(CELTDecoder * restrict st, unsigned char *data, int len, c
       transient_time = -1;
       transient_shift = 0;
    }
-   /* Get the pitch gains */
    
-   /* Get the pitch index */
    if (has_pitch)
    {
-      has_pitch = unquant_pitch(gains, st->mode->nbPBands, &dec);
+      int id;
+      /* Get the pitch gains and index */
+      id = ec_dec_bits(&dec, 7);
+      unquant_pitch(id, gains, st->mode->nbPBands);
       pitch_index = ec_dec_uint(&dec, MAX_PERIOD-(2*N-2*N4));
       st->last_pitch_index = pitch_index;
    } else {
-      /* FIXME: We could be more intelligent here and just not compute the MDCT */
       pitch_index = 0;
       for (i=0;i<st->mode->nbPBands;i++)
          gains[i] = 0;
@@ -999,16 +1033,15 @@ int celt_decode_float(CELTDecoder * restrict st, unsigned char *data, int len, c
       ALLOC(bandEp, st->mode->nbEBands*C, celt_ener_t);
       compute_band_energies(st->mode, freq, bandEp);
       normalise_bands(st->mode, freq, P, bandEp);
+      /* Apply pitch gains */
+      pitch_quant_bands(st->mode, P, gains);
    } else {
       for (i=0;i<C*N;i++)
          P[i] = 0;
    }
 
-   /* Apply pitch gains */
-   pitch_quant_bands(st->mode, P, gains);
-
    /* Decode fixed codebook and merge with pitch */
-   unquant_bands(st->mode, X, P, bandE, stereo_mode, pulses, shortBlocks, len*8, &dec);
+   unquant_bands(st->mode, X, P, bandE, stereo_mode, pulses, shortBlocks, has_fold, len*8, &dec);
 
    if (C==2)
    {