From 52e0797f8e1c631eecf24cb3f997ace336f52271 Mon Sep 17 00:00:00 2001 From: Raymond Hettinger Date: Fri, 11 Aug 2023 17:19:19 +0100 Subject: [PATCH] Extend _sqrtprod() to cover the full range of inputs. Add tests. (GH-107855) --- Lib/statistics.py | 24 ++++++++--- Lib/test/test_statistics.py | 79 +++++++++++++++++++++++++++++++++++++ 2 files changed, 97 insertions(+), 6 deletions(-) diff --git a/Lib/statistics.py b/Lib/statistics.py index 93c44f1f13f..a8036e9928c 100644 --- a/Lib/statistics.py +++ b/Lib/statistics.py @@ -137,6 +137,7 @@ from decimal import Decimal from itertools import count, groupby, repeat from bisect import bisect_left, bisect_right from math import hypot, sqrt, fabs, exp, erf, tau, log, fsum, sumprod +from math import isfinite, isinf from functools import reduce from operator import itemgetter from collections import Counter, namedtuple, defaultdict @@ -1005,14 +1006,25 @@ def _mean_stdev(data): return float(xbar), float(xbar) / float(ss) def _sqrtprod(x: float, y: float) -> float: - "Return sqrt(x * y) computed with high accuracy." - # Square root differential correction: - # https://www.wolframalpha.com/input/?i=Maclaurin+series+sqrt%28h**2+%2B+x%29+at+x%3D0 + "Return sqrt(x * y) computed with improved accuracy and without overflow/underflow." h = sqrt(x * y) + if not isfinite(h): + if isinf(h) and not isinf(x) and not isinf(y): + # Finite inputs overflowed, so scale down, and recompute. + scale = 2.0 ** -512 # sqrt(1 / sys.float_info.max) + return _sqrtprod(scale * x, scale * y) / scale + return h if not h: - return 0.0 - x = sumprod((x, h), (y, -h)) - return h + x / (2.0 * h) + if x and y: + # Non-zero inputs underflowed, so scale up, and recompute. + # Scale: 1 / sqrt(sys.float_info.min * sys.float_info.epsilon) + scale = 2.0 ** 537 + return _sqrtprod(scale * x, scale * y) / scale + return h + # Improve accuracy with a differential correction. 
+ # https://www.wolframalpha.com/input/?i=Maclaurin+series+sqrt%28h**2+%2B+x%29+at+x%3D0 + d = sumprod((x, h), (y, -h)) + return h + d / (2.0 * h) # === Statistics for relations between two inputs === diff --git a/Lib/test/test_statistics.py b/Lib/test/test_statistics.py index f0fa6454b1f..aa2cf2b1edc 100644 --- a/Lib/test/test_statistics.py +++ b/Lib/test/test_statistics.py @@ -28,6 +28,12 @@ import statistics # === Helper functions and class === +# Test copied from Lib/test/test_math.py +# detect evidence of double-rounding: fsum is not always correctly +# rounded on machines that suffer from double rounding. +x, y = 1e16, 2.9999 # use temporary values to defeat peephole optimizer +HAVE_DOUBLE_ROUNDING = (x + y == 1e16 + 4) + def sign(x): """Return -1.0 for negatives, including -0.0, otherwise +1.0.""" return math.copysign(1, x) @@ -2564,6 +2570,79 @@ class TestCorrelationAndCovariance(unittest.TestCase): self.assertAlmostEqual(statistics.correlation(x, y), 1) self.assertAlmostEqual(statistics.covariance(x, y), 0.1) + def test_sqrtprod_helper_function_fundamentals(self): + # Verify that results are close to sqrt(x * y) + for i in range(100): + x = random.expovariate() + y = random.expovariate() + expected = math.sqrt(x * y) + actual = statistics._sqrtprod(x, y) + with self.subTest(x=x, y=y, expected=expected, actual=actual): + self.assertAlmostEqual(expected, actual) + + x, y, target = 0.8035720646477457, 0.7957468097636939, 0.7996498651651661 + self.assertEqual(statistics._sqrtprod(x, y), target) + self.assertNotEqual(math.sqrt(x * y), target) + + # Test that range extremes avoid underflow and overflow + smallest = sys.float_info.min * sys.float_info.epsilon + self.assertEqual(statistics._sqrtprod(smallest, smallest), smallest) + biggest = sys.float_info.max + self.assertEqual(statistics._sqrtprod(biggest, biggest), biggest) + + # Check special values and the sign of the result + special_values = [0.0, -0.0, 1.0, -1.0, 4.0, -4.0, + math.nan, -math.nan, math.inf, 
-math.inf] + for x, y in itertools.product(special_values, repeat=2): + try: + expected = math.sqrt(x * y) + except ValueError: + expected = 'ValueError' + try: + actual = statistics._sqrtprod(x, y) + except ValueError: + actual = 'ValueError' + with self.subTest(x=x, y=y, expected=expected, actual=actual): + if isinstance(expected, str) and expected == 'ValueError': + self.assertEqual(actual, 'ValueError') + continue + self.assertIsInstance(actual, float) + if math.isnan(expected): + self.assertTrue(math.isnan(actual)) + continue + self.assertEqual(actual, expected) + self.assertEqual(sign(actual), sign(expected)) + + @requires_IEEE_754 + @unittest.skipIf(HAVE_DOUBLE_ROUNDING, + "accuracy not guaranteed on machines with double rounding") + @support.cpython_only # Allow for a weaker sumprod() implementation + def test_sqrtprod_helper_function_improved_accuracy(self): + # Test a known example where accuracy is improved + x, y, target = 0.8035720646477457, 0.7957468097636939, 0.7996498651651661 + self.assertEqual(statistics._sqrtprod(x, y), target) + self.assertNotEqual(math.sqrt(x * y), target) + + def reference_value(x: float, y: float) -> float: + x = decimal.Decimal(x) + y = decimal.Decimal(y) + with decimal.localcontext() as ctx: + ctx.prec = 200 + return float((x * y).sqrt()) + + # Verify that the new function with improved accuracy + # agrees with a reference value more often than old version. + new_agreements = 0 + old_agreements = 0 + for i in range(10_000): + x = random.expovariate() + y = random.expovariate() + new = statistics._sqrtprod(x, y) + old = math.sqrt(x * y) + ref = reference_value(x, y) + new_agreements += (new == ref) + old_agreements += (old == ref) + self.assertGreater(new_agreements, old_agreements) def test_correlation_spearman(self): # https://statistics.laerd.com/statistical-guides/spearmans-rank-order-correlation-statistical-guide-2.php