Saving the state rather than re-compute the best option
[opus.git] / celt / bands.c
index 5088ee8..bab1684 100644 (file)
@@ -1346,6 +1346,9 @@ void quant_all_bands(int encode, const CELTMode *m, int start, int end,
    VARDECL(celt_norm, _lowband_scratch);
    VARDECL(celt_norm, X_save);
    VARDECL(celt_norm, Y_save);
+   VARDECL(celt_norm, X_save2);
+   VARDECL(celt_norm, Y_save2);
+   VARDECL(celt_norm, norm_save2);
    int resynth_alloc;
    celt_norm *lowband_scratch;
    int B;
@@ -1386,6 +1389,9 @@ void quant_all_bands(int encode, const CELTMode *m, int start, int end,
       lowband_scratch = X_+M*eBands[m->nbEBands-1];
    ALLOC(X_save, resynth_alloc, celt_norm);
    ALLOC(Y_save, resynth_alloc, celt_norm);
+   ALLOC(X_save2, resynth_alloc, celt_norm);
+   ALLOC(Y_save2, resynth_alloc, celt_norm);
+   ALLOC(norm_save2, resynth_alloc, celt_norm);
 
    lowband_offset = 0;
    ctx.bandE = bandE;
@@ -1497,10 +1503,13 @@ void quant_all_bands(int encode, const CELTMode *m, int start, int end,
          {
             if (theta_rdo)
             {
-               ec_ctx ec_save;
-               struct band_ctx ctx_save;
+               ec_ctx ec_save, ec_save2;
+               struct band_ctx ctx_save, ctx_save2;
                opus_val32 dist0, dist1;
-               unsigned cm;
+               unsigned cm, cm2;
+               int nstart_bytes, nend_bytes, save_bytes;
+               unsigned char *bytes_buf;
+               unsigned char bytes_save[1275];
                /* Make a copy. */
                cm = x_cm|y_cm;
                ec_save = *ec;
@@ -1513,6 +1522,21 @@ void quant_all_bands(int encode, const CELTMode *m, int start, int end,
                      effective_lowband != -1 ? norm+effective_lowband : NULL, LM,
                      last?NULL:norm+M*eBands[i]-norm_offset, lowband_scratch, cm);
                dist0 = celt_inner_prod(X_save, X, N, arch) + celt_inner_prod(Y_save, Y, N, arch);
+
+               /* Save first result. */
+               cm2 = x_cm;
+               ec_save2 = *ec;
+               ctx_save2 = ctx;
+               OPUS_COPY(X_save2, X, N);
+               OPUS_COPY(Y_save2, Y, N);
+               if (!last)
+                  OPUS_COPY(norm_save2, norm+M*eBands[i]-norm_offset, N);
+               nstart_bytes = ec_save.offs;
+               nend_bytes = ec_save.storage;
+               bytes_buf = ec_save.buf+nstart_bytes;
+               save_bytes = nend_bytes-nstart_bytes;
+               OPUS_COPY(bytes_save, bytes_buf, save_bytes);
+
                /* Restore */
                *ec = ec_save;
                ctx = ctx_save;
@@ -1524,16 +1548,16 @@ void quant_all_bands(int encode, const CELTMode *m, int start, int end,
                      effective_lowband != -1 ? norm+effective_lowband : NULL, LM,
                      last?NULL:norm+M*eBands[i]-norm_offset, lowband_scratch, cm);
                dist1 = celt_inner_prod(X_save, X, N, arch) + celt_inner_prod(Y_save, Y, N, arch);
-               /* Restore */
-               *ec = ec_save;
-               ctx = ctx_save;
-               OPUS_COPY(X, X_save, N);
-               OPUS_COPY(Y, Y_save, N);
-               /* Encode with best choice. */
-               ctx.theta_round = dist0 >= dist1 ? -1 : 1;
-               x_cm = quant_band_stereo(&ctx, X, Y, N, b, B,
-                     effective_lowband != -1 ? norm+effective_lowband : NULL, LM,
-                     last?NULL:norm+M*eBands[i]-norm_offset, lowband_scratch, cm);
+               if (dist0 >= dist1) {
+                  x_cm = cm2;
+                  *ec = ec_save2;
+                  ctx = ctx_save2;
+                  OPUS_COPY(X, X_save2, N);
+                  OPUS_COPY(Y, Y_save2, N);
+                  if (!last)
+                     OPUS_COPY(norm+M*eBands[i]-norm_offset, norm_save2, N);
+                  OPUS_COPY(bytes_buf, bytes_save, save_bytes);
+               }
             } else {
                ctx.theta_round = 0;
                x_cm = quant_band_stereo(&ctx, X, Y, N, b, B,