Making it easier to Torben to develop his new PLC code
[opus.git] / libcelt / celt.c
index bdf96e1..05af26c 100644 (file)
@@ -1,5 +1,5 @@
 /* (C) 2007-2008 Jean-Marc Valin, CSIRO
-*/
+   (C) 2008 Gregory Maxwell */
 /*
    Redistribution and use in source and binary forms, with or without
    modification, are permitted provided that the following conditions
@@ -44,7 +44,6 @@
 #include "bands.h"
 #include "modes.h"
 #include "entcode.h"
-#include "quant_pitch.h"
 #include "quant_bands.h"
 #include "psy.h"
 #include "rate.h"
@@ -77,6 +76,7 @@ struct CELTEncoder {
    int channels;
    
    int pitch_enabled;
+   int pitch_available;
 
    celt_word16_t * restrict preemph_memE; /* Input is 16-bit, so why bother with 32 */
    celt_sig_t    * restrict preemph_memD;
@@ -109,6 +109,7 @@ CELTEncoder *celt_encoder_create(const CELTMode *mode)
    st->overlap = mode->overlap;
 
    st->pitch_enabled = 1;
+   st->pitch_available = 1;
 
    st->in_mem = celt_alloc(st->overlap*C*sizeof(celt_sig_t));
    st->out_mem = celt_alloc((MAX_PERIOD+st->overlap)*C*sizeof(celt_sig_t));
@@ -377,7 +378,6 @@ int celt_encode_float(CELTEncoder * restrict st, const celt_sig_t * pcm, celt_si
    int has_fold=1;
    ec_byte_buffer buf;
    ec_enc         enc;
-   celt_word32_t curr_power, pitch_power=0;
    VARDECL(celt_sig_t, in);
    VARDECL(celt_sig_t, freq);
    VARDECL(celt_norm_t, X);
@@ -404,6 +404,9 @@ int celt_encode_float(CELTEncoder * restrict st, const celt_sig_t * pcm, celt_si
    if (check_mode(st->mode) != CELT_OK)
       return CELT_INVALID_MODE;
 
+   if (nbCompressedBytes<0)
+     return CELT_BAD_ARG; 
+
    /* The memset is important for now in case the encoder doesn't fill up all the bytes */
    CELT_MEMSET(compressed, 0, nbCompressedBytes);
    ec_byte_writeinit_buffer(&buf, compressed, nbCompressedBytes);
@@ -438,8 +441,8 @@ int celt_encode_float(CELTEncoder * restrict st, const celt_sig_t * pcm, celt_si
 #ifndef FIXED_POINT
          float gain_1;
 #endif
-         ec_enc_bits(&enc, 0, 1); //Pitch off
-         ec_enc_bits(&enc, 1, 1); //Transient on
+         ec_enc_bits(&enc, 0, 1); /*Pitch off */
+         ec_enc_bits(&enc, 1, 1); /*Transient on */
          ec_enc_bits(&enc, transient_shift, 2);
          if (transient_shift)
             ec_enc_uint(&enc, transient_time, N+st->overlap);
@@ -476,6 +479,8 @@ int celt_encode_float(CELTEncoder * restrict st, const celt_sig_t * pcm, celt_si
    }
 
    /* Pitch analysis: we do it early to save on the peak stack space */
+   /* Don't use pitch if there isn't enough data available yet, or if we're using shortBlocks */
+   has_pitch = st->pitch_enabled && (st->pitch_available >= MAX_PERIOD) && (!shortBlocks);
 #ifdef EXP_PSY
    ALLOC(tonality, MAX_PERIOD/4, celt_word16_t);
    {
@@ -485,7 +490,7 @@ int celt_encode_float(CELTEncoder * restrict st, const celt_sig_t * pcm, celt_si
       compute_tonality(st->mode, X, st->psy_mem, MAX_PERIOD, tonality, MAX_PERIOD/4);
    }
 #else
-   if (st->pitch_enabled && !shortBlocks)
+   if (has_pitch)
    {
       find_spectral_pitch(st->mode, st->mode->fft, &st->mode->psy, in, st->out_mem, st->mode->window, NULL, 2*N-2*N4, MAX_PERIOD-(2*N-2*N4), &pitch_index);
    }
@@ -540,8 +545,9 @@ int celt_encode_float(CELTEncoder * restrict st, const celt_sig_t * pcm, celt_si
 #endif
 
    /* Compute MDCTs of the pitch part */
-   if (st->pitch_enabled && !shortBlocks)
+   if (has_pitch)
    {
+      celt_word32_t curr_power, pitch_power=0;
       /* Normalise the pitch vector as well (discard the energies) */
       VARDECL(celt_ener_t, bandEp);
       
@@ -550,31 +556,22 @@ int celt_encode_float(CELTEncoder * restrict st, const celt_sig_t * pcm, celt_si
       compute_band_energies(st->mode, freq, bandEp);
       normalise_bands(st->mode, freq, P, bandEp);
       pitch_power = bandEp[0]+bandEp[1]+bandEp[2];
-   }
-
-   /* Check if we can safely use the pitch (i.e. effective gain isn't too high) */
-   curr_power = bandE[0]+bandE[1]+bandE[2];
-   if (st->pitch_enabled && !shortBlocks && (MULT16_32_Q15(QCONST16(.1f, 15),curr_power) + QCONST32(10.f,ENER_SHIFT) < pitch_power))
-   {
-      int id;
-
-      /* Pitch prediction */
-      compute_pitch_gain(st->mode, X, P, gains);
-      id = quant_pitch(gains, st->mode->nbPBands, &enc);
-      if (id != -1)
-         has_pitch = 1;
-      else
-         has_pitch = 0;
-      ec_enc_bits(&enc, has_pitch, 1); /* Pitch flag */
-      if (has_pitch)
+      /* Check if we can safely use the pitch (i.e. effective gain isn't too high) */
+      curr_power = bandE[0]+bandE[1]+bandE[2];
+      if ((MULT16_32_Q15(QCONST16(.1f, 15),curr_power) + QCONST32(10.f,ENER_SHIFT) < pitch_power))
       {
-         ec_enc_bits(&enc, has_fold, 1); /* Folding flag */
-         ec_enc_bits(&enc, id, 7);
-         ec_enc_uint(&enc, pitch_index, MAX_PERIOD-(2*N-2*N4));
-      } else if (st->mode->nbShortMdcts > 1) {
-         ec_enc_bits(&enc, 0, 1); /* Transient off */
-         has_fold = 1;
+         /* Pitch prediction */
+         has_pitch = compute_pitch_gain(st->mode, X, P, gains);
+      } else {
+         has_pitch = 0;
       }
+   }
+   
+   if (has_pitch) 
+   {  
+      ec_enc_bits(&enc, has_pitch, 1); /* Pitch flag */
+      ec_enc_bits(&enc, has_fold, 1); /* Folding flag */
+      ec_enc_uint(&enc, pitch_index, MAX_PERIOD-(2*N-2*N4));
    } else {
       if (!shortBlocks)
       {
@@ -618,22 +615,26 @@ int celt_encode_float(CELTEncoder * restrict st, const celt_sig_t * pcm, celt_si
    for (i=0;i<st->mode->nbEBands;i++)
       offsets[i] = 0;
    bits = nbCompressedBytes*8 - ec_enc_tell(&enc, 0) - 1;
+   if (has_pitch)
+      bits -= st->mode->nbPBands;
 #ifndef STDIN_TUNING
    compute_allocation(st->mode, offsets, stereo_mode, bits, pulses, fine_quant);
 #endif
 
    quant_fine_energy(st->mode, bandE, st->oldBandE, error, fine_quant, &enc);
 
-   pitch_quant_bands(st->mode, P, gains);
-
    /* Residual quantisation */
-   quant_bands(st->mode, X, P, NULL, bandE, stereo_mode, pulses, shortBlocks, has_fold, nbCompressedBytes*8, &enc);
+   if (C==1)
+      quant_bands(st->mode, X, P, NULL, has_pitch, gains, bandE, pulses, shortBlocks, has_fold, nbCompressedBytes*8, &enc);
+   else
+      quant_bands_stereo(st->mode, X, P, NULL, has_pitch, gains, bandE, pulses, shortBlocks, has_fold, nbCompressedBytes*8, &enc);
 
    /* Re-synthesis of the coded audio if required */
-   if (st->pitch_enabled || optional_synthesis!=NULL)
+   if (st->pitch_available>0 || optional_synthesis!=NULL)
    {
-      if (C==2)
-         renormalise_bands(st->mode, X);
+      if (st->pitch_available>0 && st->pitch_available<MAX_PERIOD)
+        st->pitch_available+=st->frame_size;
+
       /* Synthesis */
       denormalise_bands(st->mode, X, freq, bandE);
       
@@ -656,10 +657,11 @@ int celt_encode_float(CELTEncoder * restrict st, const celt_sig_t * pcm, celt_si
          }
       }
    }
-   /*fprintf (stderr, "remaining bits after encode = %d\n", nbCompressedBytes*8-ec_enc_tell(&st->enc, 0));*/
-   /*if (ec_enc_tell(&st->enc, 0) < nbCompressedBytes*8 - 7)
-      celt_warning_int ("many unused bits: ", nbCompressedBytes*8-ec_enc_tell(&st->enc, 0));*/
-   /*printf ("%d\n", ec_enc_tell(&st->enc, 0)-8*nbCompressedBytes);*/
+
+   /*fprintf (stderr, "remaining bits after encode = %d\n", nbCompressedBytes*8-ec_enc_tell(&enc, 0));*/
+   /*if (ec_enc_tell(&enc, 0) < nbCompressedBytes*8 - 7)
+      celt_warning_int ("many unused bits: ", nbCompressedBytes*8-ec_enc_tell(&enc, 0));*/
+
    /* Finishing the stream with a 0101... pattern so that the decoder can check is everything's right */
    {
       int val = 0;
@@ -701,8 +703,7 @@ int celt_encode_float(CELTEncoder * restrict st, const float * pcm, float * opti
 
    if (optional_synthesis != NULL) {
      ret=celt_encode(st,in,in,compressed,nbCompressedBytes);
-   /*Converts backwards for inplace operation*/
-      for (j=0;j=C*N;j++)
+      for (j=0;j<C*N;j++)
          optional_synthesis[j]=in[j]*(1/32768.);
    } else {
      ret=celt_encode(st,in,NULL,compressed,nbCompressedBytes);
@@ -748,7 +749,22 @@ int celt_encoder_ctl(CELTEncoder * restrict st, int request, ...)
          int value = va_arg(ap, int);
          if (value<0 || value>10)
             goto bad_arg;
-         if (value<=2)
+         if (value<=2) {
+            st->pitch_enabled = 0; 
+            st->pitch_available = 0;
+         } else {
+              st->pitch_enabled = 1;
+              if (st->pitch_available<1)
+                st->pitch_available = 1;
+         }   
+      }
+      break;
+      case CELT_SET_LTP_REQUEST:
+      {
+         int value = va_arg(ap, int);
+         if (value<0 || value>1 || (value==1 && st->pitch_available==0))
+            goto bad_arg;
+         if (value==0)
             st->pitch_enabled = 0;
          else
             st->pitch_enabled = 1;
@@ -772,7 +788,11 @@ bad_request:
 /*                                DECODER                                   */
 /*                                                                          */
 /****************************************************************************/
-
+#ifdef NEW_PLC
+#define DECODE_BUFFER_SIZE 2048
+#else
+#define DECODE_BUFFER_SIZE MAX_PERIOD
+#endif
 
 /** Decoder state 
  @brief Decoder state
@@ -789,6 +809,7 @@ struct CELTDecoder {
    celt_sig_t * restrict preemph_memD;
 
    celt_sig_t *out_mem;
+   celt_sig_t *decode_mem;
 
    celt_word16_t *oldBandE;
    
@@ -812,7 +833,8 @@ CELTDecoder *celt_decoder_create(const CELTMode *mode)
    st->block_size = N;
    st->overlap = mode->overlap;
 
-   st->out_mem = celt_alloc((MAX_PERIOD+st->overlap)*C*sizeof(celt_sig_t));
+   st->decode_mem = celt_alloc((DECODE_BUFFER_SIZE+st->overlap)*C*sizeof(celt_sig_t));
+   st->out_mem = st->decode_mem+DECODE_BUFFER_SIZE-MAX_PERIOD;
    
    st->oldBandE = (celt_word16_t*)celt_alloc(C*mode->nbEBands*sizeof(celt_word16_t));
 
@@ -833,7 +855,7 @@ void celt_decoder_destroy(CELTDecoder *st)
       return;
 
 
-   celt_free(st->out_mem);
+   celt_free(st->decode_mem);
    
    celt_free(st->oldBandE);
    
@@ -844,6 +866,9 @@ void celt_decoder_destroy(CELTDecoder *st)
 
 /** Handles lost packets by just copying past data with the same offset as the last
     pitch period */
+#ifdef NEW_PLC
+#include "plc.c"
+#else
 static void celt_decode_lost(CELTDecoder * restrict st, celt_word16_t * restrict pcm)
 {
    int c, N;
@@ -893,12 +918,13 @@ static void celt_decode_lost(CELTDecoder * restrict st, celt_word16_t * restrict
    }
    RESTORE_STACK;
 }
+#endif
 
 #ifdef FIXED_POINT
-int celt_decode(CELTDecoder * restrict st, unsigned char *data, int len, celt_int16_t * restrict pcm)
+int celt_decode(CELTDecoder * restrict st, const unsigned char *data, int len, celt_int16_t * restrict pcm)
 {
 #else
-int celt_decode_float(CELTDecoder * restrict st, unsigned char *data, int len, celt_sig_t * restrict pcm)
+int celt_decode_float(CELTDecoder * restrict st, const unsigned char *data, int len, celt_sig_t * restrict pcm)
 {
 #endif
    int i, c, N, N4;
@@ -946,6 +972,10 @@ int celt_decode_float(CELTDecoder * restrict st, unsigned char *data, int len, c
       RESTORE_STACK;
       return 0;
    }
+   if (len<0) {
+     RESTORE_STACK;
+     return CELT_BAD_ARG;
+   }
    
    ec_byte_readinit(&buf,data,len);
    ec_dec_init(&dec,&buf);
@@ -973,16 +1003,12 @@ int celt_decode_float(CELTDecoder * restrict st, unsigned char *data, int len, c
       transient_time = -1;
       transient_shift = 0;
    }
-   /* Get the pitch gains */
    
-   /* Get the pitch index */
    if (has_pitch)
    {
-      has_pitch = unquant_pitch(gains, st->mode->nbPBands, &dec);
       pitch_index = ec_dec_uint(&dec, MAX_PERIOD-(2*N-2*N4));
       st->last_pitch_index = pitch_index;
    } else {
-      /* FIXME: We could be more intelligent here and just not compute the MDCT */
       pitch_index = 0;
       for (i=0;i<st->mode->nbPBands;i++)
          gains[i] = 0;
@@ -1001,6 +1027,8 @@ int celt_decode_float(CELTDecoder * restrict st, unsigned char *data, int len, c
       offsets[i] = 0;
 
    bits = len*8 - ec_dec_tell(&dec, 0) - 1;
+   if (has_pitch)
+      bits -= st->mode->nbPBands;
    compute_allocation(st->mode, offsets, stereo_mode, bits, pulses, fine_quant);
    /*bits = ec_dec_tell(&dec, 0);
    compute_fine_allocation(st->mode, fine_quant, (20*C+len*8/5-(ec_dec_tell(&dec, 0)-bits))/C);*/
@@ -1017,26 +1045,23 @@ int celt_decode_float(CELTDecoder * restrict st, unsigned char *data, int len, c
       ALLOC(bandEp, st->mode->nbEBands*C, celt_ener_t);
       compute_band_energies(st->mode, freq, bandEp);
       normalise_bands(st->mode, freq, P, bandEp);
+      /* Apply pitch gains */
    } else {
       for (i=0;i<C*N;i++)
          P[i] = 0;
    }
 
-   /* Apply pitch gains */
-   pitch_quant_bands(st->mode, P, gains);
-
    /* Decode fixed codebook and merge with pitch */
-   unquant_bands(st->mode, X, P, bandE, stereo_mode, pulses, shortBlocks, has_fold, len*8, &dec);
+   if (C==1)
+      unquant_bands(st->mode, X, P, has_pitch, gains, bandE, pulses, shortBlocks, has_fold, len*8, &dec);
+   else
+      unquant_bands_stereo(st->mode, X, P, has_pitch, gains, bandE, pulses, shortBlocks, has_fold, len*8, &dec);
 
-   if (C==2)
-   {
-      renormalise_bands(st->mode, X);
-   }
    /* Synthesis */
    denormalise_bands(st->mode, X, freq, bandE);
 
 
-   CELT_MOVE(st->out_mem, st->out_mem+C*N, C*(MAX_PERIOD+st->overlap-N));
+   CELT_MOVE(st->decode_mem, st->decode_mem+C*N, C*(DECODE_BUFFER_SIZE+st->overlap-N));
    /* Compute inverse MDCTs */
    compute_inv_mdcts(st->mode, shortBlocks, freq, transient_time, transient_shift, st->out_mem);
 
@@ -1073,7 +1098,7 @@ int celt_decode_float(CELTDecoder * restrict st, unsigned char *data, int len, c
 
 #ifdef FIXED_POINT
 #ifndef DISABLE_FLOAT_API
-int celt_decode_float(CELTDecoder * restrict st, unsigned char *data, int len, float * restrict pcm)
+int celt_decode_float(CELTDecoder * restrict st, const unsigned char *data, int len, float * restrict pcm)
 {
    int j, ret;
    const int C = CHANNELS(st->mode);
@@ -1091,7 +1116,7 @@ int celt_decode_float(CELTDecoder * restrict st, unsigned char *data, int len, f
 }
 #endif /*DISABLE_FLOAT_API*/
 #else
-int celt_decode(CELTDecoder * restrict st, unsigned char *data, int len, celt_int16_t * restrict pcm)
+int celt_decode(CELTDecoder * restrict st, const unsigned char *data, int len, celt_int16_t * restrict pcm)
 {
    int j, ret;
    VARDECL(celt_sig_t, out);