mirror of
https://github.com/python/cpython.git
synced 2024-11-25 02:44:06 +08:00
Extend _sqrtprod() to cover the full range of inputs. Add tests. (GH-107855)
This commit is contained in:
parent
637f7ff2c6
commit
52e0797f8e
@ -137,6 +137,7 @@ from decimal import Decimal
|
||||
from itertools import count, groupby, repeat
|
||||
from bisect import bisect_left, bisect_right
|
||||
from math import hypot, sqrt, fabs, exp, erf, tau, log, fsum, sumprod
|
||||
from math import isfinite, isinf
|
||||
from functools import reduce
|
||||
from operator import itemgetter
|
||||
from collections import Counter, namedtuple, defaultdict
|
||||
@ -1005,14 +1006,25 @@ def _mean_stdev(data):
|
||||
return float(xbar), float(xbar) / float(ss)
|
||||
|
||||
def _sqrtprod(x: float, y: float) -> float:
|
||||
"Return sqrt(x * y) computed with high accuracy."
|
||||
# Square root differential correction:
|
||||
# https://www.wolframalpha.com/input/?i=Maclaurin+series+sqrt%28h**2+%2B+x%29+at+x%3D0
|
||||
"Return sqrt(x * y) computed with improved accuracy and without overflow/underflow."
|
||||
h = sqrt(x * y)
|
||||
if not isfinite(h):
|
||||
if isinf(h) and not isinf(x) and not isinf(y):
|
||||
# Finite inputs overflowed, so scale down, and recompute.
|
||||
scale = 2.0 ** -512 # sqrt(1 / sys.float_info.max)
|
||||
return _sqrtprod(scale * x, scale * y) / scale
|
||||
return h
|
||||
if not h:
|
||||
return 0.0
|
||||
x = sumprod((x, h), (y, -h))
|
||||
return h + x / (2.0 * h)
|
||||
if x and y:
|
||||
# Non-zero inputs underflowed, so scale up, and recompute.
|
||||
# Scale: 1 / sqrt(sys.float_info.min * sys.float_info.epsilon)
|
||||
scale = 2.0 ** 537
|
||||
return _sqrtprod(scale * x, scale * y) / scale
|
||||
return h
|
||||
# Improve accuracy with a differential correction.
|
||||
# https://www.wolframalpha.com/input/?i=Maclaurin+series+sqrt%28h**2+%2B+x%29+at+x%3D0
|
||||
d = sumprod((x, h), (y, -h))
|
||||
return h + d / (2.0 * h)
|
||||
|
||||
|
||||
# === Statistics for relations between two inputs ===
|
||||
|
@ -28,6 +28,12 @@ import statistics
|
||||
|
||||
# === Helper functions and class ===
|
||||
|
||||
# Test copied from Lib/test/test_math.py
|
||||
# detect evidence of double-rounding: fsum is not always correctly
|
||||
# rounded on machines that suffer from double rounding.
|
||||
x, y = 1e16, 2.9999 # use temporary values to defeat peephole optimizer
|
||||
HAVE_DOUBLE_ROUNDING = (x + y == 1e16 + 4)
|
||||
|
||||
def sign(x):
|
||||
"""Return -1.0 for negatives, including -0.0, otherwise +1.0."""
|
||||
return math.copysign(1, x)
|
||||
@ -2564,6 +2570,79 @@ class TestCorrelationAndCovariance(unittest.TestCase):
|
||||
self.assertAlmostEqual(statistics.correlation(x, y), 1)
|
||||
self.assertAlmostEqual(statistics.covariance(x, y), 0.1)
|
||||
|
||||
def test_sqrtprod_helper_function_fundamentals(self):
|
||||
# Verify that results are close to sqrt(x * y)
|
||||
for i in range(100):
|
||||
x = random.expovariate()
|
||||
y = random.expovariate()
|
||||
expected = math.sqrt(x * y)
|
||||
actual = statistics._sqrtprod(x, y)
|
||||
with self.subTest(x=x, y=y, expected=expected, actual=actual):
|
||||
self.assertAlmostEqual(expected, actual)
|
||||
|
||||
x, y, target = 0.8035720646477457, 0.7957468097636939, 0.7996498651651661
|
||||
self.assertEqual(statistics._sqrtprod(x, y), target)
|
||||
self.assertNotEqual(math.sqrt(x * y), target)
|
||||
|
||||
# Test that range extremes avoid underflow and overflow
|
||||
smallest = sys.float_info.min * sys.float_info.epsilon
|
||||
self.assertEqual(statistics._sqrtprod(smallest, smallest), smallest)
|
||||
biggest = sys.float_info.max
|
||||
self.assertEqual(statistics._sqrtprod(biggest, biggest), biggest)
|
||||
|
||||
# Check special values and the sign of the result
|
||||
special_values = [0.0, -0.0, 1.0, -1.0, 4.0, -4.0,
|
||||
math.nan, -math.nan, math.inf, -math.inf]
|
||||
for x, y in itertools.product(special_values, repeat=2):
|
||||
try:
|
||||
expected = math.sqrt(x * y)
|
||||
except ValueError:
|
||||
expected = 'ValueError'
|
||||
try:
|
||||
actual = statistics._sqrtprod(x, y)
|
||||
except ValueError:
|
||||
actual = 'ValueError'
|
||||
with self.subTest(x=x, y=y, expected=expected, actual=actual):
|
||||
if isinstance(expected, str) and expected == 'ValueError':
|
||||
self.assertEqual(actual, 'ValueError')
|
||||
continue
|
||||
self.assertIsInstance(actual, float)
|
||||
if math.isnan(expected):
|
||||
self.assertTrue(math.isnan(actual))
|
||||
continue
|
||||
self.assertEqual(actual, expected)
|
||||
self.assertEqual(sign(actual), sign(expected))
|
||||
|
||||
@requires_IEEE_754
|
||||
@unittest.skipIf(HAVE_DOUBLE_ROUNDING,
|
||||
"accuracy not guaranteed on machines with double rounding")
|
||||
@support.cpython_only # Allow for a weaker sumprod() implmentation
|
||||
def test_sqrtprod_helper_function_improved_accuracy(self):
|
||||
# Test a known example where accuracy is improved
|
||||
x, y, target = 0.8035720646477457, 0.7957468097636939, 0.7996498651651661
|
||||
self.assertEqual(statistics._sqrtprod(x, y), target)
|
||||
self.assertNotEqual(math.sqrt(x * y), target)
|
||||
|
||||
def reference_value(x: float, y: float) -> float:
|
||||
x = decimal.Decimal(x)
|
||||
y = decimal.Decimal(y)
|
||||
with decimal.localcontext() as ctx:
|
||||
ctx.prec = 200
|
||||
return float((x * y).sqrt())
|
||||
|
||||
# Verify that the new function with improved accuracy
|
||||
# agrees with a reference value more often than old version.
|
||||
new_agreements = 0
|
||||
old_agreements = 0
|
||||
for i in range(10_000):
|
||||
x = random.expovariate()
|
||||
y = random.expovariate()
|
||||
new = statistics._sqrtprod(x, y)
|
||||
old = math.sqrt(x * y)
|
||||
ref = reference_value(x, y)
|
||||
new_agreements += (new == ref)
|
||||
old_agreements += (old == ref)
|
||||
self.assertGreater(new_agreements, old_agreements)
|
||||
|
||||
def test_correlation_spearman(self):
|
||||
# https://statistics.laerd.com/statistical-guides/spearmans-rank-order-correlation-statistical-guide-2.php
|
||||
|
Loading…
Reference in New Issue
Block a user