mirror of
https://github.com/python/cpython.git
synced 2024-11-27 11:55:13 +08:00
Rework integer overflow path in math.prod and add more tests (GH-11809)
The overflow check was relying on undefined behaviour as it was using the result of the multiplication to do the check, and once the overflow has already happened, any operation on the result is undefined behaviour. Some extra checks that exercise code paths related to this are also added.
This commit is contained in:
parent
62fa51f121
commit
0411411c6b
@ -1595,6 +1595,92 @@ class MathTests(unittest.TestCase):
|
||||
self.fail('Failures in test_mtestfile:\n ' +
|
||||
'\n '.join(failures))
|
||||
|
||||
def test_prod(self):
|
||||
prod = math.prod
|
||||
self.assertEqual(prod([]), 1)
|
||||
self.assertEqual(prod([], start=5), 5)
|
||||
self.assertEqual(prod(list(range(2,8))), 5040)
|
||||
self.assertEqual(prod(iter(list(range(2,8)))), 5040)
|
||||
self.assertEqual(prod(range(1, 10), start=10), 3628800)
|
||||
|
||||
self.assertEqual(prod([1, 2, 3, 4, 5]), 120)
|
||||
self.assertEqual(prod([1.0, 2.0, 3.0, 4.0, 5.0]), 120.0)
|
||||
self.assertEqual(prod([1, 2, 3, 4.0, 5.0]), 120.0)
|
||||
self.assertEqual(prod([1.0, 2.0, 3.0, 4, 5]), 120.0)
|
||||
|
||||
# Test overflow in fast-path for integers
|
||||
self.assertEqual(prod([1, 1, 2**32, 1, 1]), 2**32)
|
||||
# Test overflow in fast-path for floats
|
||||
self.assertEqual(prod([1.0, 1.0, 2**32, 1, 1]), float(2**32))
|
||||
|
||||
self.assertRaises(TypeError, prod)
|
||||
self.assertRaises(TypeError, prod, 42)
|
||||
self.assertRaises(TypeError, prod, ['a', 'b', 'c'])
|
||||
self.assertRaises(TypeError, prod, ['a', 'b', 'c'], '')
|
||||
self.assertRaises(TypeError, prod, [b'a', b'c'], b'')
|
||||
values = [bytearray(b'a'), bytearray(b'b')]
|
||||
self.assertRaises(TypeError, prod, values, bytearray(b''))
|
||||
self.assertRaises(TypeError, prod, [[1], [2], [3]])
|
||||
self.assertRaises(TypeError, prod, [{2:3}])
|
||||
self.assertRaises(TypeError, prod, [{2:3}]*2, {2:3})
|
||||
self.assertRaises(TypeError, prod, [[1], [2], [3]], [])
|
||||
with self.assertRaises(TypeError):
|
||||
prod([10, 20], [30, 40]) # start is a keyword-only argument
|
||||
|
||||
self.assertEqual(prod([0, 1, 2, 3]), 0)
|
||||
self.assertEqual(prod([1, 0, 2, 3]), 0)
|
||||
self.assertEqual(prod([1, 2, 3, 0]), 0)
|
||||
|
||||
def _naive_prod(iterable, start=1):
|
||||
for elem in iterable:
|
||||
start *= elem
|
||||
return start
|
||||
|
||||
# Big integers
|
||||
|
||||
iterable = range(1, 10000)
|
||||
self.assertEqual(prod(iterable), _naive_prod(iterable))
|
||||
iterable = range(-10000, -1)
|
||||
self.assertEqual(prod(iterable), _naive_prod(iterable))
|
||||
iterable = range(-1000, 1000)
|
||||
self.assertEqual(prod(iterable), 0)
|
||||
|
||||
# Big floats
|
||||
|
||||
iterable = [float(x) for x in range(1, 1000)]
|
||||
self.assertEqual(prod(iterable), _naive_prod(iterable))
|
||||
iterable = [float(x) for x in range(-1000, -1)]
|
||||
self.assertEqual(prod(iterable), _naive_prod(iterable))
|
||||
iterable = [float(x) for x in range(-1000, 1000)]
|
||||
self.assertIsNaN(prod(iterable))
|
||||
|
||||
# Float tests
|
||||
|
||||
self.assertIsNaN(prod([1, 2, 3, float("nan"), 2, 3]))
|
||||
self.assertIsNaN(prod([1, 0, float("nan"), 2, 3]))
|
||||
self.assertIsNaN(prod([1, float("nan"), 0, 3]))
|
||||
self.assertIsNaN(prod([1, float("inf"), float("nan"),3]))
|
||||
self.assertIsNaN(prod([1, float("-inf"), float("nan"),3]))
|
||||
self.assertIsNaN(prod([1, float("nan"), float("inf"),3]))
|
||||
self.assertIsNaN(prod([1, float("nan"), float("-inf"),3]))
|
||||
|
||||
self.assertEqual(prod([1, 2, 3, float('inf'),-3,4]), float('-inf'))
|
||||
self.assertEqual(prod([1, 2, 3, float('-inf'),-3,4]), float('inf'))
|
||||
|
||||
self.assertIsNaN(prod([1,2,0,float('inf'), -3, 4]))
|
||||
self.assertIsNaN(prod([1,2,0,float('-inf'), -3, 4]))
|
||||
self.assertIsNaN(prod([1, 2, 3, float('inf'), -3, 0, 3]))
|
||||
self.assertIsNaN(prod([1, 2, 3, float('-inf'), -3, 0, 2]))
|
||||
|
||||
# Type preservation
|
||||
|
||||
self.assertEqual(type(prod([1, 2, 3, 4, 5, 6])), int)
|
||||
self.assertEqual(type(prod([1, 2.0, 3, 4, 5, 6])), float)
|
||||
self.assertEqual(type(prod(range(1, 10000))), int)
|
||||
self.assertEqual(type(prod(range(1, 10000), start=1.0)), float)
|
||||
self.assertEqual(type(prod([1, decimal.Decimal(2.0), 3, 4, 5, 6])),
|
||||
decimal.Decimal)
|
||||
|
||||
# Custom assertions.
|
||||
|
||||
def assertIsNaN(self, value):
|
||||
@ -1724,41 +1810,6 @@ class IsCloseTests(unittest.TestCase):
|
||||
self.assertAllClose(fraction_examples, rel_tol=1e-8)
|
||||
self.assertAllNotClose(fraction_examples, rel_tol=1e-9)
|
||||
|
||||
def test_prod(self):
|
||||
prod = math.prod
|
||||
self.assertEqual(prod([]), 1)
|
||||
self.assertEqual(prod([], start=5), 5)
|
||||
self.assertEqual(prod(list(range(2,8))), 5040)
|
||||
self.assertEqual(prod(iter(list(range(2,8)))), 5040)
|
||||
self.assertEqual(prod(range(1, 10), start=10), 3628800)
|
||||
|
||||
self.assertEqual(prod([1, 2, 3, 4, 5]), 120)
|
||||
self.assertEqual(prod([1.0, 2.0, 3.0, 4.0, 5.0]), 120.0)
|
||||
self.assertEqual(prod([1, 2, 3, 4.0, 5.0]), 120.0)
|
||||
self.assertEqual(prod([1.0, 2.0, 3.0, 4, 5]), 120.0)
|
||||
|
||||
# Test overflow in fast-path for integers
|
||||
self.assertEqual(prod([1, 1, 2**32, 1, 1]), 2**32)
|
||||
# Test overflow in fast-path for floats
|
||||
self.assertEqual(prod([1.0, 1.0, 2**32, 1, 1]), float(2**32))
|
||||
|
||||
self.assertRaises(TypeError, prod)
|
||||
self.assertRaises(TypeError, prod, 42)
|
||||
self.assertRaises(TypeError, prod, ['a', 'b', 'c'])
|
||||
self.assertRaises(TypeError, prod, ['a', 'b', 'c'], '')
|
||||
self.assertRaises(TypeError, prod, [b'a', b'c'], b'')
|
||||
values = [bytearray(b'a'), bytearray(b'b')]
|
||||
self.assertRaises(TypeError, prod, values, bytearray(b''))
|
||||
self.assertRaises(TypeError, prod, [[1], [2], [3]])
|
||||
self.assertRaises(TypeError, prod, [{2:3}])
|
||||
self.assertRaises(TypeError, prod, [{2:3}]*2, {2:3})
|
||||
self.assertRaises(TypeError, prod, [[1], [2], [3]], [])
|
||||
with self.assertRaises(TypeError):
|
||||
prod([10, 20], [30, 40]) # start is a keyword-only argument
|
||||
|
||||
self.assertEqual(prod([0, 1, 2, 3]), 0)
|
||||
self.assertEqual(prod([1, 0, 2, 3]), 0)
|
||||
self.assertEqual(prod(range(10)), 0)
|
||||
|
||||
def test_main():
|
||||
from doctest import DocFileSuite
|
||||
|
@ -2493,6 +2493,55 @@ math_isclose_impl(PyObject *module, double a, double b, double rel_tol,
|
||||
(diff <= abs_tol));
|
||||
}
|
||||
|
||||
static inline int
|
||||
_check_long_mult_overflow(long a, long b) {
|
||||
|
||||
/* From Python2's int_mul code:
|
||||
|
||||
Integer overflow checking for * is painful: Python tried a couple ways, but
|
||||
they didn't work on all platforms, or failed in endcases (a product of
|
||||
-sys.maxint-1 has been a particular pain).
|
||||
|
||||
Here's another way:
|
||||
|
||||
The native long product x*y is either exactly right or *way* off, being
|
||||
just the last n bits of the true product, where n is the number of bits
|
||||
in a long (the delivered product is the true product plus i*2**n for
|
||||
some integer i).
|
||||
|
||||
The native double product (double)x * (double)y is subject to three
|
||||
rounding errors: on a sizeof(long)==8 box, each cast to double can lose
|
||||
info, and even on a sizeof(long)==4 box, the multiplication can lose info.
|
||||
But, unlike the native long product, it's not in *range* trouble: even
|
||||
if sizeof(long)==32 (256-bit longs), the product easily fits in the
|
||||
dynamic range of a double. So the leading 50 (or so) bits of the double
|
||||
product are correct.
|
||||
|
||||
We check these two ways against each other, and declare victory if they're
|
||||
approximately the same. Else, because the native long product is the only
|
||||
one that can lose catastrophic amounts of information, it's the native long
|
||||
product that must have overflowed.
|
||||
|
||||
*/
|
||||
|
||||
long longprod = (long)((unsigned long)a * b);
|
||||
double doubleprod = (double)a * (double)b;
|
||||
double doubled_longprod = (double)longprod;
|
||||
|
||||
if (doubled_longprod == doubleprod) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
const double diff = doubled_longprod - doubleprod;
|
||||
const double absdiff = diff >= 0.0 ? diff : -diff;
|
||||
const double absprod = doubleprod >= 0.0 ? doubleprod : -doubleprod;
|
||||
|
||||
if (32.0 * absdiff <= absprod) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
||||
/*[clinic input]
|
||||
math.prod
|
||||
@ -2558,11 +2607,8 @@ math_prod_impl(PyObject *module, PyObject *iterable, PyObject *start)
|
||||
}
|
||||
if (PyLong_CheckExact(item)) {
|
||||
long b = PyLong_AsLongAndOverflow(item, &overflow);
|
||||
long x = i_result * b;
|
||||
/* Continue if there is no overflow */
|
||||
if (overflow == 0
|
||||
&& x < LONG_MAX && x > LONG_MIN
|
||||
&& !(b != 0 && x / b != i_result)) {
|
||||
if (overflow == 0 && !_check_long_mult_overflow(i_result, b)) {
|
||||
long x = i_result * b;
|
||||
i_result = x;
|
||||
Py_DECREF(item);
|
||||
continue;
|
||||
|
Loading…
Reference in New Issue
Block a user