From 60b8f5b80307c1ecb1070f602074c791444c52bc Mon Sep 17 00:00:00 2001 From: Jean-Marc Valin Date: Thu, 3 Apr 2025 14:45:24 -0400 Subject: [PATCH] Add SHL32_ovflw() and use it in IMDCT Prevents integer overflow UB in the headroom left shift when the signal blows up on bad bitstreams (if it triggers, the signal was already unusable anyway). --- celt/arch.h | 1 + celt/fixed_debug.h | 2 ++ celt/fixed_generic.h | 2 ++ celt/mdct.c | 4 ++-- 4 files changed, 7 insertions(+), 2 deletions(-) diff --git a/celt/arch.h b/celt/arch.h index e8321ac8..53611e13 100644 --- a/celt/arch.h +++ b/celt/arch.h @@ -330,6 +330,7 @@ static OPUS_INLINE int celt_isnan(float x) #define SUB32(a,b) ((a)-(b)) #define ADD32_ovflw(a,b) ((a)+(b)) #define SUB32_ovflw(a,b) ((a)-(b)) +#define SHL32_ovflw(a,shift) (a) #define PSHR32_ovflw(a,shift) (a) #define MULT16_16_16(a,b) ((a)*(b)) diff --git a/celt/fixed_debug.h b/celt/fixed_debug.h index 4b92a1e3..a879ec37 100644 --- a/celt/fixed_debug.h +++ b/celt/fixed_debug.h @@ -69,6 +69,8 @@ extern opus_int64 celt_mips; /* Avoid MSVC warning C4146: unary minus operator applied to unsigned type */ /** Negate 32-bit value, ignore any overflows */ #define NEG32_ovflw(a) (celt_mips+=2,(opus_val32)(0-(opus_uint32)(a))) +/** 32-bit shift left, ignoring overflows */ +#define SHL32_ovflw(a,shift) ((opus_int32)((opus_uint32)(a)<<(shift))) /** 32-bit arithmetic shift right with rounding-to-nearest, ignoring overflows */ #define PSHR32_ovflw(a,shift) (SHR32(ADD32_ovflw(a, (EXTEND32(1)<<(shift)>>1)),shift)) diff --git a/celt/fixed_generic.h b/celt/fixed_generic.h index 3de7bd6a..a4834df5 100644 --- a/celt/fixed_generic.h +++ b/celt/fixed_generic.h @@ -147,6 +147,8 @@ /* Avoid MSVC warning C4146: unary minus operator applied to unsigned type */ /** Negate 32-bit value, ignore any overflows */ #define NEG32_ovflw(a) ((opus_val32)(0-(opus_uint32)(a))) +/** 32-bit shift left, ignoring overflows */ +#define SHL32_ovflw(a,shift) SHL32(a,shift) /** 32-bit arithmetic shift right with rounding-to-nearest, ignoring overflows */ #define PSHR32_ovflw(a,shift) (SHR32(ADD32_ovflw(a, (EXTEND32(1)<<(shift)>>1)),shift)) diff --git a/celt/mdct.c b/celt/mdct.c index ba6b22eb..b03e8a21 100644 --- a/celt/mdct.c +++ b/celt/mdct.c @@ -297,8 +297,8 @@ void clt_mdct_backward_c(const mdct_lookup *l, kiss_fft_scalar *in, kiss_fft_sca kiss_fft_scalar yr, yi; opus_val32 x1, x2; rev = *bitrev++; - x1 = SHL32(*xp1, IMDCT_HEADROOM); - x2 = SHL32(*xp2, IMDCT_HEADROOM); + x1 = SHL32_ovflw(*xp1, IMDCT_HEADROOM); + x2 = SHL32_ovflw(*xp2, IMDCT_HEADROOM); yr = ADD32_ovflw(S_MUL(x2, t[i]), S_MUL(x1, t[N4+i])); yi = SUB32_ovflw(S_MUL(x1, t[i]), S_MUL(x2, t[N4+i])); /* We swap real and imag because we use an FFT instead of an IFFT. */