Extend _sqrtprod() to cover the full range of inputs. Add tests. (GH-107855)

This commit is contained in:
Raymond Hettinger 2023-08-11 17:19:19 +01:00 committed by GitHub
parent 637f7ff2c6
commit 52e0797f8e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 97 additions and 6 deletions

View File

@ -137,6 +137,7 @@ from decimal import Decimal
from itertools import count, groupby, repeat
from bisect import bisect_left, bisect_right
from math import hypot, sqrt, fabs, exp, erf, tau, log, fsum, sumprod
from math import isfinite, isinf
from functools import reduce
from operator import itemgetter
from collections import Counter, namedtuple, defaultdict
@ -1005,14 +1006,25 @@ def _mean_stdev(data):
return float(xbar), float(xbar) / float(ss)
def _sqrtprod(x: float, y: float) -> float:
"Return sqrt(x * y) computed with high accuracy."
# Square root differential correction:
# https://www.wolframalpha.com/input/?i=Maclaurin+series+sqrt%28h**2+%2B+x%29+at+x%3D0
"Return sqrt(x * y) computed with improved accuracy and without overflow/underflow."
h = sqrt(x * y)
if not isfinite(h):
if isinf(h) and not isinf(x) and not isinf(y):
# Finite inputs overflowed, so scale down, and recompute.
scale = 2.0 ** -512 # sqrt(1 / sys.float_info.max)
return _sqrtprod(scale * x, scale * y) / scale
return h
if not h:
return 0.0
x = sumprod((x, h), (y, -h))
return h + x / (2.0 * h)
if x and y:
# Non-zero inputs underflowed, so scale up, and recompute.
# Scale: 1 / sqrt(sys.float_info.min * sys.float_info.epsilon)
scale = 2.0 ** 537
return _sqrtprod(scale * x, scale * y) / scale
return h
# Improve accuracy with a differential correction.
# https://www.wolframalpha.com/input/?i=Maclaurin+series+sqrt%28h**2+%2B+x%29+at+x%3D0
d = sumprod((x, h), (y, -h))
return h + d / (2.0 * h)
# === Statistics for relations between two inputs ===

View File

@ -28,6 +28,12 @@ import statistics
# === Helper functions and class ===
# Test copied from Lib/test/test_math.py
# Detect evidence of double rounding: fsum() is not always correctly
# rounded on machines that suffer from double rounding.
# Adjacent floats near 1e16 are 2 apart, so the correctly rounded value
# of 1e16 + 2.9999 is 1e16 + 2.  Getting 1e16 + 4 instead is taken as
# evidence that the addition was rounded more than once.
x, y = 1e16, 2.9999 # use temporary values to defeat peephole optimizer
HAVE_DOUBLE_ROUNDING = (x + y == 1e16 + 4)
def sign(x):
    """Return -1.0 for negatives, including -0.0, otherwise +1.0.

    The result is always a float.  The decision is made from the sign
    bit of *x* (via math.copysign), which is why -0.0 maps to -1.0.
    """
    return math.copysign(1.0, x)
@ -2564,6 +2570,79 @@ class TestCorrelationAndCovariance(unittest.TestCase):
self.assertAlmostEqual(statistics.correlation(x, y), 1)
self.assertAlmostEqual(statistics.covariance(x, y), 0.1)
def test_sqrtprod_helper_function_fundamentals(self):
    # Verify that results are close to sqrt(x * y)
    for _ in range(100):
        x = random.expovariate()
        y = random.expovariate()
        expected = math.sqrt(x * y)
        actual = statistics._sqrtprod(x, y)
        with self.subTest(x=x, y=y, expected=expected, actual=actual):
            self.assertAlmostEqual(expected, actual)

    # The exact-result check for a known hard case is platform dependent,
    # so it is left to test_sqrtprod_helper_function_improved_accuracy,
    # which is skipped where that accuracy cannot be guaranteed.

    # Test that range extremes avoid underflow and overflow
    smallest = sys.float_info.min * sys.float_info.epsilon
    self.assertEqual(statistics._sqrtprod(smallest, smallest), smallest)
    biggest = sys.float_info.max
    self.assertEqual(statistics._sqrtprod(biggest, biggest), biggest)

    # Check special values and the sign of the result.  The helper must
    # behave exactly like math.sqrt(x * y): same exception, same NaN
    # propagation, same value, and the same sign for zero results.
    special_values = [0.0, -0.0, 1.0, -1.0, 4.0, -4.0,
                      math.nan, -math.nan, math.inf, -math.inf]
    for x, y in itertools.product(special_values, repeat=2):
        try:
            expected = math.sqrt(x * y)
        except ValueError:
            expected = 'ValueError'
        try:
            actual = statistics._sqrtprod(x, y)
        except ValueError:
            actual = 'ValueError'
        with self.subTest(x=x, y=y, expected=expected, actual=actual):
            if expected == 'ValueError':
                self.assertEqual(actual, 'ValueError')
                continue
            self.assertIsInstance(actual, float)
            if math.isnan(expected):
                # NaN never compares equal, so test for it explicitly.
                self.assertTrue(math.isnan(actual))
                continue
            self.assertEqual(actual, expected)
            self.assertEqual(sign(actual), sign(expected))
@requires_IEEE_754
@unittest.skipIf(HAVE_DOUBLE_ROUNDING,
                 "accuracy not guaranteed on machines with double rounding")
@support.cpython_only    # Allow for a weaker sumprod() implementation
def test_sqrtprod_helper_function_improved_accuracy(self):
    # Test a known example where accuracy is improved
    x, y, target = 0.8035720646477457, 0.7957468097636939, 0.7996498651651661
    self.assertEqual(statistics._sqrtprod(x, y), target)
    self.assertNotEqual(math.sqrt(x * y), target)

    def reference_value(x: float, y: float) -> float:
        # Compute sqrt(x * y) at 200 significant digits with decimal,
        # then convert the result to a float.
        x = decimal.Decimal(x)
        y = decimal.Decimal(y)
        with decimal.localcontext() as ctx:
            ctx.prec = 200
            return float((x * y).sqrt())

    # Verify that the new function with improved accuracy
    # agrees with a reference value more often than old version.
    new_agreements = 0
    old_agreements = 0
    for _ in range(10_000):
        x = random.expovariate()
        y = random.expovariate()
        new = statistics._sqrtprod(x, y)
        old = math.sqrt(x * y)
        ref = reference_value(x, y)
        new_agreements += (new == ref)
        old_agreements += (old == ref)
    self.assertGreater(new_agreements, old_agreements)
def test_correlation_spearman(self):
# https://statistics.laerd.com/statistical-guides/spearmans-rank-order-correlation-statistical-guide-2.php