Alternative fix for CVE-2022-4304

This is about a timing leak in the topmost limb
of the internal result of RSA_private_decrypt,
before the padding check.

There are in fact at least three bugs together that
caused the timing leak:

First and probably most important is the fact that
the blinding did not use the constant time code path
at all when the RSA object was used for a private
decrypt, due to the fact that the Montgomery context
rsa->_method_mod_n was not set up early enough in
rsa_ossl_private_decrypt, when BN_BLINDING_create_param
needed it, and that was persisted as blinding->m_ctx,
although the RSA object creates the Montgomery context
just a bit later.

Then the infamous bn_correct_top was used on the
secret value right after the blinding was removed.

And finally the function BN_bn2binpad did not use
the constant-time code path since the BN_FLG_CONSTTIME
was not set on the secret value.

In order to address the first problem, this patch
makes sure that the rsa->_method_mod_n is initialized
right before the blinding context.

And to fix the second problem, we add a new utility
function bn_correct_top_consttime, a const-time
variant of bn_correct_top.

Together with the fact that BN_bn2binpad is already
constant time if the flag BN_FLG_CONSTTIME is set
(which rsa_blinding_invert now sets on the secret value),
this should eliminate the timing oracle completely.

In addition the no-asm variant may also have
branches that depend on secret values, because the last
invocation of bn_sub_words in bn_from_montgomery_word
has branches when the function is compiled by certain
gcc versions, due to the clumsy coding style.

So additionally this patch streamlines the no-asm
C-code in order to avoid branches where possible and
improve the resulting code quality.

Reviewed-by: Paul Dale <pauli@openssl.org>
Reviewed-by: Tomas Mraz <tomas@openssl.org>
(Merged from https://github.com/openssl/openssl/pull/20281)
This commit is contained in:
Bernd Edlinger 2023-02-13 17:46:41 +01:00 committed by Tomas Mraz
parent 4209ce68d8
commit f06ef1657a
6 changed files with 112 additions and 69 deletions

View File

@ -25,6 +25,17 @@ OpenSSL 3.2
### Changes between 3.1 and 3.2 [xx XXX xxxx]
* Reworked the Fix for the Timing Oracle in RSA Decryption ([CVE-2022-4304]).
The previous fix for this timing side channel turned out to cause
a severe 2-3x performance regression in the typical use case
compared to 3.0.7. The new fix uses existing constant time
code paths, and restores the previous performance level while
fully eliminating all existing timing side channels.
The fix was developed by Bernd Edlinger with testing support
by Hubert Kario.
*Bernd Edlinger*
* Added an "advanced" command mode to s_client. Use this with the "-adv"
option. The old "basic" command mode recognises certain letters that must
always appear at the start of a line and cannot be escaped. The advanced

View File

@ -381,25 +381,33 @@ BN_ULONG bn_sub_words(BN_ULONG *r, const BN_ULONG *a, const BN_ULONG *b,
#ifndef OPENSSL_SMALL_FOOTPRINT
while (n & ~3) {
t1 = a[0];
t2 = b[0];
r[0] = (t1 - t2 - c) & BN_MASK2;
if (t1 != t2)
c = (t1 < t2);
t2 = (t1 - c) & BN_MASK2;
c = (t2 > t1);
t1 = b[0];
t1 = (t2 - t1) & BN_MASK2;
r[0] = t1;
c += (t1 > t2);
t1 = a[1];
t2 = b[1];
r[1] = (t1 - t2 - c) & BN_MASK2;
if (t1 != t2)
c = (t1 < t2);
t2 = (t1 - c) & BN_MASK2;
c = (t2 > t1);
t1 = b[1];
t1 = (t2 - t1) & BN_MASK2;
r[1] = t1;
c += (t1 > t2);
t1 = a[2];
t2 = b[2];
r[2] = (t1 - t2 - c) & BN_MASK2;
if (t1 != t2)
c = (t1 < t2);
t2 = (t1 - c) & BN_MASK2;
c = (t2 > t1);
t1 = b[2];
t1 = (t2 - t1) & BN_MASK2;
r[2] = t1;
c += (t1 > t2);
t1 = a[3];
t2 = b[3];
r[3] = (t1 - t2 - c) & BN_MASK2;
if (t1 != t2)
c = (t1 < t2);
t2 = (t1 - c) & BN_MASK2;
c = (t2 > t1);
t1 = b[3];
t1 = (t2 - t1) & BN_MASK2;
r[3] = t1;
c += (t1 > t2);
a += 4;
b += 4;
r += 4;
@ -408,10 +416,12 @@ BN_ULONG bn_sub_words(BN_ULONG *r, const BN_ULONG *a, const BN_ULONG *b,
#endif
while (n) {
t1 = a[0];
t2 = b[0];
r[0] = (t1 - t2 - c) & BN_MASK2;
if (t1 != t2)
c = (t1 < t2);
t2 = (t1 - c) & BN_MASK2;
c = (t2 > t1);
t1 = b[0];
t1 = (t2 - t1) & BN_MASK2;
r[0] = t1;
c += (t1 > t2);
a++;
b++;
r++;
@ -441,7 +451,7 @@ BN_ULONG bn_sub_words(BN_ULONG *r, const BN_ULONG *a, const BN_ULONG *b,
t += c0; /* no carry */ \
c0 = (BN_ULONG)Lw(t); \
hi = (BN_ULONG)Hw(t); \
c1 = (c1+hi)&BN_MASK2; if (c1<hi) c2++; \
c1 = (c1+hi)&BN_MASK2; c2 += (c1<hi); \
} while(0)
# define mul_add_c2(a,b,c0,c1,c2) do { \
@ -450,11 +460,11 @@ BN_ULONG bn_sub_words(BN_ULONG *r, const BN_ULONG *a, const BN_ULONG *b,
BN_ULLONG tt = t+c0; /* no carry */ \
c0 = (BN_ULONG)Lw(tt); \
hi = (BN_ULONG)Hw(tt); \
c1 = (c1+hi)&BN_MASK2; if (c1<hi) c2++; \
c1 = (c1+hi)&BN_MASK2; c2 += (c1<hi); \
t += c0; /* no carry */ \
c0 = (BN_ULONG)Lw(t); \
hi = (BN_ULONG)Hw(t); \
c1 = (c1+hi)&BN_MASK2; if (c1<hi) c2++; \
c1 = (c1+hi)&BN_MASK2; c2 += (c1<hi); \
} while(0)
# define sqr_add_c(a,i,c0,c1,c2) do { \
@ -463,7 +473,7 @@ BN_ULONG bn_sub_words(BN_ULONG *r, const BN_ULONG *a, const BN_ULONG *b,
t += c0; /* no carry */ \
c0 = (BN_ULONG)Lw(t); \
hi = (BN_ULONG)Hw(t); \
c1 = (c1+hi)&BN_MASK2; if (c1<hi) c2++; \
c1 = (c1+hi)&BN_MASK2; c2 += (c1<hi); \
} while(0)
# define sqr_add_c2(a,i,j,c0,c1,c2) \
@ -478,26 +488,26 @@ BN_ULONG bn_sub_words(BN_ULONG *r, const BN_ULONG *a, const BN_ULONG *b,
BN_ULONG ta = (a), tb = (b); \
BN_ULONG lo, hi; \
BN_UMULT_LOHI(lo,hi,ta,tb); \
c0 += lo; hi += (c0<lo)?1:0; \
c1 += hi; c2 += (c1<hi)?1:0; \
c0 += lo; hi += (c0<lo); \
c1 += hi; c2 += (c1<hi); \
} while(0)
# define mul_add_c2(a,b,c0,c1,c2) do { \
BN_ULONG ta = (a), tb = (b); \
BN_ULONG lo, hi, tt; \
BN_UMULT_LOHI(lo,hi,ta,tb); \
c0 += lo; tt = hi+((c0<lo)?1:0); \
c1 += tt; c2 += (c1<tt)?1:0; \
c0 += lo; hi += (c0<lo)?1:0; \
c1 += hi; c2 += (c1<hi)?1:0; \
c0 += lo; tt = hi + (c0<lo); \
c1 += tt; c2 += (c1<tt); \
c0 += lo; hi += (c0<lo); \
c1 += hi; c2 += (c1<hi); \
} while(0)
# define sqr_add_c(a,i,c0,c1,c2) do { \
BN_ULONG ta = (a)[i]; \
BN_ULONG lo, hi; \
BN_UMULT_LOHI(lo,hi,ta,ta); \
c0 += lo; hi += (c0<lo)?1:0; \
c1 += hi; c2 += (c1<hi)?1:0; \
c0 += lo; hi += (c0<lo); \
c1 += hi; c2 += (c1<hi); \
} while(0)
# define sqr_add_c2(a,i,j,c0,c1,c2) \
@ -512,26 +522,26 @@ BN_ULONG bn_sub_words(BN_ULONG *r, const BN_ULONG *a, const BN_ULONG *b,
BN_ULONG ta = (a), tb = (b); \
BN_ULONG lo = ta * tb; \
BN_ULONG hi = BN_UMULT_HIGH(ta,tb); \
c0 += lo; hi += (c0<lo)?1:0; \
c1 += hi; c2 += (c1<hi)?1:0; \
c0 += lo; hi += (c0<lo); \
c1 += hi; c2 += (c1<hi); \
} while(0)
# define mul_add_c2(a,b,c0,c1,c2) do { \
BN_ULONG ta = (a), tb = (b), tt; \
BN_ULONG lo = ta * tb; \
BN_ULONG hi = BN_UMULT_HIGH(ta,tb); \
c0 += lo; tt = hi + ((c0<lo)?1:0); \
c1 += tt; c2 += (c1<tt)?1:0; \
c0 += lo; hi += (c0<lo)?1:0; \
c1 += hi; c2 += (c1<hi)?1:0; \
c0 += lo; tt = hi + (c0<lo); \
c1 += tt; c2 += (c1<tt); \
c0 += lo; hi += (c0<lo); \
c1 += hi; c2 += (c1<hi); \
} while(0)
# define sqr_add_c(a,i,c0,c1,c2) do { \
BN_ULONG ta = (a)[i]; \
BN_ULONG lo = ta * ta; \
BN_ULONG hi = BN_UMULT_HIGH(ta,ta); \
c0 += lo; hi += (c0<lo)?1:0; \
c1 += hi; c2 += (c1<hi)?1:0; \
c0 += lo; hi += (c0<lo); \
c1 += hi; c2 += (c1<hi); \
} while(0)
# define sqr_add_c2(a,i,j,c0,c1,c2) \
@ -546,8 +556,8 @@ BN_ULONG bn_sub_words(BN_ULONG *r, const BN_ULONG *a, const BN_ULONG *b,
BN_ULONG lo = LBITS(a), hi = HBITS(a); \
BN_ULONG bl = LBITS(b), bh = HBITS(b); \
mul64(lo,hi,bl,bh); \
c0 = (c0+lo)&BN_MASK2; if (c0<lo) hi++; \
c1 = (c1+hi)&BN_MASK2; if (c1<hi) c2++; \
c0 = (c0+lo)&BN_MASK2; hi += (c0<lo); \
c1 = (c1+hi)&BN_MASK2; c2 += (c1<hi); \
} while(0)
# define mul_add_c2(a,b,c0,c1,c2) do { \
@ -556,17 +566,17 @@ BN_ULONG bn_sub_words(BN_ULONG *r, const BN_ULONG *a, const BN_ULONG *b,
BN_ULONG bl = LBITS(b), bh = HBITS(b); \
mul64(lo,hi,bl,bh); \
tt = hi; \
c0 = (c0+lo)&BN_MASK2; if (c0<lo) tt++; \
c1 = (c1+tt)&BN_MASK2; if (c1<tt) c2++; \
c0 = (c0+lo)&BN_MASK2; if (c0<lo) hi++; \
c1 = (c1+hi)&BN_MASK2; if (c1<hi) c2++; \
c0 = (c0+lo)&BN_MASK2; tt += (c0<lo); \
c1 = (c1+tt)&BN_MASK2; c2 += (c1<tt); \
c0 = (c0+lo)&BN_MASK2; hi += (c0<lo); \
c1 = (c1+hi)&BN_MASK2; c2 += (c1<hi); \
} while(0)
# define sqr_add_c(a,i,c0,c1,c2) do { \
BN_ULONG lo, hi; \
sqr64(lo,hi,(a)[i]); \
c0 = (c0+lo)&BN_MASK2; if (c0<lo) hi++; \
c1 = (c1+hi)&BN_MASK2; if (c1<hi) c2++; \
c0 = (c0+lo)&BN_MASK2; hi += (c0<lo); \
c1 = (c1+hi)&BN_MASK2; c2 += (c1<hi); \
} while(0)
# define sqr_add_c2(a,i,j,c0,c1,c2) \

View File

@ -189,7 +189,8 @@ int BN_BLINDING_invert_ex(BIGNUM *n, const BIGNUM *r, BN_BLINDING *b,
n->top = (int)(rtop & ~mask) | (ntop & mask);
n->flags |= (BN_FLG_FIXED_TOP & ~mask);
}
ret = BN_mod_mul_montgomery(n, n, r, b->m_ctx, ctx);
ret = bn_mul_mont_fixed_top(n, n, r, b->m_ctx, ctx);
bn_correct_top_consttime(n);
} else {
ret = BN_mod_mul(n, n, r, b->mod, ctx);
}

View File

@ -1106,6 +1106,28 @@ BIGNUM *bn_wexpand(BIGNUM *a, int words)
return (words <= a->dmax) ? a : bn_expand2(a, words);
}
/*
 * Constant-time variant of bn_correct_top: recompute a->top using mask
 * arithmetic instead of branches on the limb values.
 *
 * The loop always visits all a->dmax limbs, so the iteration count does
 * not depend on a->top or on the limb contents. On return a->top holds
 * the value accumulated in atop (see the loop comments), a->neg is reset
 * to 0 when that value is 0, and BN_FLG_FIXED_TOP is cleared from
 * a->flags.
 *
 * NOTE(review): the exact semantics of the constant_time_* helpers are not
 * visible here; the comments below assume they behave as their names
 * suggest (all-ones mask on "true", all-zeros on "false") - confirm.
 */
void bn_correct_top_consttime(BIGNUM *a)
{
int j, atop;
BN_ULONG limb;
unsigned int mask;
for (j = 0, atop = 0; j < a->dmax; j++) {
limb = a->d[j];
/*
 * Collapse the limb to all-ones if it is non-zero and to zero otherwise:
 * x | (0 - x) has its top bit set exactly when x != 0.
 */
limb |= 0 - limb;
/* presumably BN_BITS2 is the bit width of BN_ULONG, leaving 0 or 1 */
limb >>= BN_BITS2 - 1;
limb = 0 - limb;
mask = (unsigned int)limb;
/*
 * presumably all-ones only while j - a->top is negative, i.e. limbs at
 * or above the current top are ignored - TODO confirm
 */
mask &= constant_time_msb(j - a->top);
/* keep j + 1 whenever the mask is set; the last such j wins */
atop = constant_time_select_int(mask, j + 1, atop);
}
/* mask is set when no limb qualified, i.e. atop is 0 */
mask = constant_time_eq_int(atop, 0);
a->top = atop;
/* reset the sign to 0 in that case, otherwise keep a->neg unchanged */
a->neg = constant_time_select_int(mask, 0, a->neg);
a->flags &= ~BN_FLG_FIXED_TOP;
}
void bn_correct_top(BIGNUM *a)
{
BN_ULONG *ftl;

View File

@ -525,10 +525,10 @@ unsigned __int64 _umul128(unsigned __int64 a, unsigned __int64 b,
ret = (r); \
BN_UMULT_LOHI(low,high,w,tmp); \
ret += (c); \
(c) = (ret<(c))?1:0; \
(c) = (ret<(c)); \
(c) += high; \
ret += low; \
(c) += (ret<low)?1:0; \
(c) += (ret<low); \
(r) = ret; \
}
@ -537,7 +537,7 @@ unsigned __int64 _umul128(unsigned __int64 a, unsigned __int64 b,
BN_UMULT_LOHI(low,high,w,ta); \
ret = low + (c); \
(c) = high; \
(c) += (ret<low)?1:0; \
(c) += (ret<low); \
(r) = ret; \
}
@ -553,10 +553,10 @@ unsigned __int64 _umul128(unsigned __int64 a, unsigned __int64 b,
high= BN_UMULT_HIGH(w,tmp); \
ret += (c); \
low = (w) * tmp; \
(c) = (ret<(c))?1:0; \
(c) = (ret<(c)); \
(c) += high; \
ret += low; \
(c) += (ret<low)?1:0; \
(c) += (ret<low); \
(r) = ret; \
}
@ -566,7 +566,7 @@ unsigned __int64 _umul128(unsigned __int64 a, unsigned __int64 b,
high= BN_UMULT_HIGH(w,ta); \
ret = low + (c); \
(c) = high; \
(c) += (ret<low)?1:0; \
(c) += (ret<low); \
(r) = ret; \
}
@ -599,10 +599,10 @@ unsigned __int64 _umul128(unsigned __int64 a, unsigned __int64 b,
lt=(bl)*(lt); \
m1=(bl)*(ht); \
ht =(bh)*(ht); \
m=(m+m1)&BN_MASK2; if (m < m1) ht+=L2HBITS((BN_ULONG)1); \
m=(m+m1)&BN_MASK2; ht += L2HBITS((BN_ULONG)(m < m1)); \
ht+=HBITS(m); \
m1=L2HBITS(m); \
lt=(lt+m1)&BN_MASK2; if (lt < m1) ht++; \
lt=(lt+m1)&BN_MASK2; ht += (lt < m1); \
(l)=lt; \
(h)=ht; \
}
@ -619,7 +619,7 @@ unsigned __int64 _umul128(unsigned __int64 a, unsigned __int64 b,
h*=h; \
h+=(m&BN_MASK2h1)>>(BN_BITS4-1); \
m =(m&BN_MASK2l)<<(BN_BITS4+1); \
l=(l+m)&BN_MASK2; if (l < m) h++; \
l=(l+m)&BN_MASK2; h += (l < m); \
(lo)=l; \
(ho)=h; \
}
@ -633,9 +633,9 @@ unsigned __int64 _umul128(unsigned __int64 a, unsigned __int64 b,
mul64(l,h,(bl),(bh)); \
\
/* non-multiply part */ \
l=(l+(c))&BN_MASK2; if (l < (c)) h++; \
l=(l+(c))&BN_MASK2; h += (l < (c)); \
(c)=(r); \
l=(l+(c))&BN_MASK2; if (l < (c)) h++; \
l=(l+(c))&BN_MASK2; h += (l < (c)); \
(c)=h&BN_MASK2; \
(r)=l; \
}
@ -649,7 +649,7 @@ unsigned __int64 _umul128(unsigned __int64 a, unsigned __int64 b,
mul64(l,h,(bl),(bh)); \
\
/* non-multiply part */ \
l+=(c); if ((l&BN_MASK2) < (c)) h++; \
l+=(c); h += ((l&BN_MASK2) < (c)); \
(c)=h&BN_MASK2; \
(r)=l&BN_MASK2; \
}
@ -679,7 +679,7 @@ BN_ULONG bn_sub_part_words(BN_ULONG *r, const BN_ULONG *a, const BN_ULONG *b,
int cl, int dl);
int bn_mul_mont(BN_ULONG *rp, const BN_ULONG *ap, const BN_ULONG *bp,
const BN_ULONG *np, const BN_ULONG *n0, int num);
void bn_correct_top_consttime(BIGNUM *a);
BIGNUM *int_bn_mod_inverse(BIGNUM *in,
const BIGNUM *a, const BIGNUM *n, BN_CTX *ctx,
int *noinv);

View File

@ -257,6 +257,7 @@ static int rsa_blinding_invert(BN_BLINDING *b, BIGNUM *f, BIGNUM *unblind,
* will only read the modulus from BN_BLINDING. In both cases it's safe
* to access the blinding without a lock.
*/
BN_set_flags(f, BN_FLG_CONSTTIME);
return BN_BLINDING_invert_ex(f, unblind, b, ctx);
}
@ -536,6 +537,11 @@ static int rsa_ossl_private_decrypt(int flen, const unsigned char *from,
goto err;
}
if (rsa->flags & RSA_FLAG_CACHE_PUBLIC)
if (!BN_MONT_CTX_set_locked(&rsa->_method_mod_n, rsa->lock,
rsa->n, ctx))
goto err;
if (!(rsa->flags & RSA_FLAG_NO_BLINDING)) {
blinding = rsa_get_blinding(rsa, &local_blinding, ctx);
if (blinding == NULL) {
@ -573,13 +579,6 @@ static int rsa_ossl_private_decrypt(int flen, const unsigned char *from,
goto err;
}
BN_with_flags(d, rsa->d, BN_FLG_CONSTTIME);
if (rsa->flags & RSA_FLAG_CACHE_PUBLIC)
if (!BN_MONT_CTX_set_locked(&rsa->_method_mod_n, rsa->lock,
rsa->n, ctx)) {
BN_free(d);
goto err;
}
if (!rsa->meth->bn_mod_exp(ret, f, d, rsa->n, ctx,
rsa->_method_mod_n)) {
BN_free(d);