mirror of
				https://github.com/python/cpython.git
				synced 2025-11-03 23:21:29 +00:00 
			
		
		
		
	Issue27181 add geometric mean.
This commit is contained in:
		
							parent
							
								
									e7fef52f98
								
							
						
					
					
						commit
						9a2be91c6b
					
				
					 2 changed files with 552 additions and 0 deletions
				
			
		| 
						 | 
					@ -303,6 +303,230 @@ def _fail_neg(values, errmsg='negative value'):
 | 
				
			||||||
        yield x
 | 
					        yield x
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class _nroot_NS:
 | 
				
			||||||
 | 
					    """Hands off! Don't touch!
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Everything inside this namespace (class) is an even-more-private
 | 
				
			||||||
 | 
					    implementation detail of the private _nth_root function.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    # This class exists only to be used as a namespace, for convenience
 | 
				
			||||||
 | 
					    # of being able to keep the related functions together, and to
 | 
				
			||||||
 | 
					    # collapse the group in an editor. If this were C# or C++, I would
 | 
				
			||||||
 | 
					    # use a Namespace, but the closest Python has is a class.
 | 
				
			||||||
 | 
					    #
 | 
				
			||||||
 | 
					    # FIXME possibly move this out into a separate module?
 | 
				
			||||||
 | 
					    # That feels like overkill, and may encourage people to treat it as
 | 
				
			||||||
 | 
					    # a public feature.
 | 
				
			||||||
 | 
					    def __init__(self):
 | 
				
			||||||
 | 
					        raise TypeError('namespace only, do not instantiate')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def nth_root(x, n):
 | 
				
			||||||
 | 
					        """Return the positive nth root of numeric x.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        This may be more accurate than ** or pow():
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        >>> math.pow(1000, 1.0/3)  #doctest:+SKIP
 | 
				
			||||||
 | 
					        9.999999999999998
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        >>> _nth_root(1000, 3)
 | 
				
			||||||
 | 
					        10.0
 | 
				
			||||||
 | 
					        >>> _nth_root(11**5, 5)
 | 
				
			||||||
 | 
					        11.0
 | 
				
			||||||
 | 
					        >>> _nth_root(2, 12)
 | 
				
			||||||
 | 
					        1.0594630943592953
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        if not isinstance(n, int):
 | 
				
			||||||
 | 
					            raise TypeError('degree n must be an int')
 | 
				
			||||||
 | 
					        if n < 2:
 | 
				
			||||||
 | 
					            raise ValueError('degree n must be 2 or more')
 | 
				
			||||||
 | 
					        if isinstance(x, decimal.Decimal):
 | 
				
			||||||
 | 
					            return _nroot_NS.decimal_nroot(x, n)
 | 
				
			||||||
 | 
					        elif isinstance(x, numbers.Real):
 | 
				
			||||||
 | 
					            return _nroot_NS.float_nroot(x, n)
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            raise TypeError('expected a number, got %s') % type(x).__name__
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def float_nroot(x, n):
 | 
				
			||||||
 | 
					        """Handle nth root of Reals, treated as a float."""
 | 
				
			||||||
 | 
					        assert isinstance(n, int) and n > 1
 | 
				
			||||||
 | 
					        if x < 0:
 | 
				
			||||||
 | 
					            if n%2 == 0:
 | 
				
			||||||
 | 
					                raise ValueError('domain error: even root of negative number')
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                return -_nroot_NS.nroot(-x, n)
 | 
				
			||||||
 | 
					        elif x == 0:
 | 
				
			||||||
 | 
					            return math.copysign(0.0, x)
 | 
				
			||||||
 | 
					        elif x > 0:
 | 
				
			||||||
 | 
					            try:
 | 
				
			||||||
 | 
					                isinfinity = math.isinf(x)
 | 
				
			||||||
 | 
					            except OverflowError:
 | 
				
			||||||
 | 
					                return _nroot_NS.bignum_nroot(x, n)
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                if isinfinity:
 | 
				
			||||||
 | 
					                    return float('inf')
 | 
				
			||||||
 | 
					                else:
 | 
				
			||||||
 | 
					                    return _nroot_NS.nroot(x, n)
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            assert math.isnan(x)
 | 
				
			||||||
 | 
					            return float('nan')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def nroot(x, n):
 | 
				
			||||||
 | 
					        """Calculate x**(1/n), then improve the answer."""
 | 
				
			||||||
 | 
					        # This uses math.pow() to calculate an initial guess for the root,
 | 
				
			||||||
 | 
					        # then uses the iterated nroot algorithm to improve it.
 | 
				
			||||||
 | 
					        #
 | 
				
			||||||
 | 
					        # By my testing, about 8% of the time the iterated algorithm ends
 | 
				
			||||||
 | 
					        # up converging to a result which is less accurate than the initial
 | 
				
			||||||
 | 
					        # guess. [FIXME: is this still true?] In that case, we use the
 | 
				
			||||||
 | 
					        # guess instead of the "improved" value. This way, we're never
 | 
				
			||||||
 | 
					        # less accurate than math.pow().
 | 
				
			||||||
 | 
					        r1 = math.pow(x, 1.0/n)
 | 
				
			||||||
 | 
					        eps1 = abs(r1**n - x)
 | 
				
			||||||
 | 
					        if eps1 == 0.0:
 | 
				
			||||||
 | 
					            # r1 is the exact root, so we're done. By my testing, this
 | 
				
			||||||
 | 
					            # occurs about 80% of the time for x < 1 and 30% of the
 | 
				
			||||||
 | 
					            # time for x > 1.
 | 
				
			||||||
 | 
					            return r1
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            try:
 | 
				
			||||||
 | 
					                r2 = _nroot_NS.iterated_nroot(x, n, r1)
 | 
				
			||||||
 | 
					            except RuntimeError:
 | 
				
			||||||
 | 
					                return r1
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                eps2 = abs(r2**n - x)
 | 
				
			||||||
 | 
					                if eps1 < eps2:
 | 
				
			||||||
 | 
					                    return r1
 | 
				
			||||||
 | 
					                return r2
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def iterated_nroot(a, n, g):
 | 
				
			||||||
 | 
					        """Return the nth root of a, starting with guess g.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        This is a special case of Newton's Method.
 | 
				
			||||||
 | 
					        https://en.wikipedia.org/wiki/Nth_root_algorithm
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        np = n - 1
 | 
				
			||||||
 | 
					        def iterate(r):
 | 
				
			||||||
 | 
					            try:
 | 
				
			||||||
 | 
					                return (np*r + a/math.pow(r, np))/n
 | 
				
			||||||
 | 
					            except OverflowError:
 | 
				
			||||||
 | 
					                # If r is large enough, r**np may overflow. If that
 | 
				
			||||||
 | 
					                # happens, r**-np will be small, but not necessarily zero.
 | 
				
			||||||
 | 
					                return (np*r + a*math.pow(r, -np))/n
 | 
				
			||||||
 | 
					        # With a good guess, such as g = a**(1/n), this will converge in
 | 
				
			||||||
 | 
					        # only a few iterations. However a poor guess can take thousands
 | 
				
			||||||
 | 
					        # of iterations to converge, if at all. We guard against poor
 | 
				
			||||||
 | 
					        # guesses by setting an upper limit to the number of iterations.
 | 
				
			||||||
 | 
					        r1 = g
 | 
				
			||||||
 | 
					        r2 = iterate(g)
 | 
				
			||||||
 | 
					        for i in range(1000):
 | 
				
			||||||
 | 
					            if r1 == r2:
 | 
				
			||||||
 | 
					                break
 | 
				
			||||||
 | 
					            # Use Floyd's cycle-finding algorithm to avoid being trapped
 | 
				
			||||||
 | 
					            # in a cycle.
 | 
				
			||||||
 | 
					            # https://en.wikipedia.org/wiki/Cycle_detection#Tortoise_and_hare
 | 
				
			||||||
 | 
					            r1 = iterate(r1)
 | 
				
			||||||
 | 
					            r2 = iterate(iterate(r2))
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            # If the guess is particularly bad, the above may fail to
 | 
				
			||||||
 | 
					            # converge in any reasonable time.
 | 
				
			||||||
 | 
					            raise RuntimeError('nth-root failed to converge')
 | 
				
			||||||
 | 
					        return r2
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def decimal_nroot(x, n):
 | 
				
			||||||
 | 
					        """Handle nth root of Decimals."""
 | 
				
			||||||
 | 
					        assert isinstance(x, decimal.Decimal)
 | 
				
			||||||
 | 
					        assert isinstance(n, int)
 | 
				
			||||||
 | 
					        if x.is_snan():
 | 
				
			||||||
 | 
					            # Signalling NANs always raise.
 | 
				
			||||||
 | 
					            raise decimal.InvalidOperation('nth-root of snan')
 | 
				
			||||||
 | 
					        if x.is_qnan():
 | 
				
			||||||
 | 
					            # Quiet NANs only raise if the context is set to raise,
 | 
				
			||||||
 | 
					            # otherwise return a NAN.
 | 
				
			||||||
 | 
					            ctx = decimal.getcontext()
 | 
				
			||||||
 | 
					            if ctx.traps[decimal.InvalidOperation]:
 | 
				
			||||||
 | 
					                raise decimal.InvalidOperation('nth-root of nan')
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                # Preserve the input NAN.
 | 
				
			||||||
 | 
					                return x
 | 
				
			||||||
 | 
					        if x.is_infinite():
 | 
				
			||||||
 | 
					            return x
 | 
				
			||||||
 | 
					        # FIXME this hasn't had the extensive testing of the float
 | 
				
			||||||
 | 
					        # version _iterated_nroot so there's possibly some buggy
 | 
				
			||||||
 | 
					        # corner cases buried in here. Can it overflow? Fail to
 | 
				
			||||||
 | 
					        # converge or get trapped in a cycle? Converge to a less
 | 
				
			||||||
 | 
					        # accurate root?
 | 
				
			||||||
 | 
					        np = n - 1
 | 
				
			||||||
 | 
					        def iterate(r):
 | 
				
			||||||
 | 
					            return (np*r + x/r**np)/n
 | 
				
			||||||
 | 
					        r0 = x**(decimal.Decimal(1)/n)
 | 
				
			||||||
 | 
					        assert isinstance(r0, decimal.Decimal)
 | 
				
			||||||
 | 
					        r1 = iterate(r0)
 | 
				
			||||||
 | 
					        while True:
 | 
				
			||||||
 | 
					            if r1 == r0:
 | 
				
			||||||
 | 
					                return r1
 | 
				
			||||||
 | 
					            r0, r1 = r1, iterate(r1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def bignum_nroot(x, n):
 | 
				
			||||||
 | 
					        """Return the nth root of a positive huge number."""
 | 
				
			||||||
 | 
					        assert x > 0
 | 
				
			||||||
 | 
					        # I state without proof that ⁿ√x ≈ ⁿ√2·ⁿ√(x//2)
 | 
				
			||||||
 | 
					        # and that for sufficiently big x the error is acceptible.
 | 
				
			||||||
 | 
					        # We now halve x until it is small enough to get the root.
 | 
				
			||||||
 | 
					        m = 0
 | 
				
			||||||
 | 
					        while True:
 | 
				
			||||||
 | 
					            x //= 2
 | 
				
			||||||
 | 
					            m += 1
 | 
				
			||||||
 | 
					            try:
 | 
				
			||||||
 | 
					                y = float(x)
 | 
				
			||||||
 | 
					            except OverflowError:
 | 
				
			||||||
 | 
					                continue
 | 
				
			||||||
 | 
					            break
 | 
				
			||||||
 | 
					        a = _nroot_NS.nroot(y, n)
 | 
				
			||||||
 | 
					        # At this point, we want the nth-root of 2**m, or 2**(m/n).
 | 
				
			||||||
 | 
					        # We can write that as 2**(q + r/n) = 2**q * ⁿ√2**r where q = m//n.
 | 
				
			||||||
 | 
					        q, r = divmod(m, n)
 | 
				
			||||||
 | 
					        b = 2**q * _nroot_NS.nroot(2**r, n)
 | 
				
			||||||
 | 
					        return a * b
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# This is the (private) function for calculating nth roots:
 | 
				
			||||||
 | 
					_nth_root = _nroot_NS.nth_root
 | 
				
			||||||
 | 
					assert type(_nth_root) is type(lambda: None)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def _product(values):
 | 
				
			||||||
 | 
					    """Return product of values as (exponent, mantissa)."""
 | 
				
			||||||
 | 
					    errmsg = 'mixed Decimal and float is not supported'
 | 
				
			||||||
 | 
					    prod = 1
 | 
				
			||||||
 | 
					    for x in values:
 | 
				
			||||||
 | 
					        if isinstance(x, float):
 | 
				
			||||||
 | 
					            break
 | 
				
			||||||
 | 
					        prod *= x
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        return (0, prod)
 | 
				
			||||||
 | 
					    if isinstance(prod, Decimal):
 | 
				
			||||||
 | 
					        raise TypeError(errmsg)
 | 
				
			||||||
 | 
					    # Since floats can overflow easily, we calculate the product as a
 | 
				
			||||||
 | 
					    # sort of poor-man's BigFloat. Given that:
 | 
				
			||||||
 | 
					    #
 | 
				
			||||||
 | 
					    #   x = 2**p * m  # p == power or exponent (scale), m = mantissa
 | 
				
			||||||
 | 
					    #
 | 
				
			||||||
 | 
					    # we can calculate the product of two (or more) x values as:
 | 
				
			||||||
 | 
					    #
 | 
				
			||||||
 | 
					    #   x1*x2 = 2**p1*m1 * 2**p2*m2 = 2**(p1+p2)*(m1*m2)
 | 
				
			||||||
 | 
					    #
 | 
				
			||||||
 | 
					    mant, scale = 1, 0  #math.frexp(prod)  # FIXME
 | 
				
			||||||
 | 
					    for y in chain([x], values):
 | 
				
			||||||
 | 
					        if isinstance(y, Decimal):
 | 
				
			||||||
 | 
					            raise TypeError(errmsg)
 | 
				
			||||||
 | 
					        m1, e1 = math.frexp(y)
 | 
				
			||||||
 | 
					        m2, e2 = math.frexp(mant)
 | 
				
			||||||
 | 
					        scale += (e1 + e2)
 | 
				
			||||||
 | 
					        mant = m1*m2
 | 
				
			||||||
 | 
					    return (scale, mant)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# === Measures of central tendency (averages) ===
 | 
					# === Measures of central tendency (averages) ===
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def mean(data):
 | 
					def mean(data):
 | 
				
			||||||
| 
						 | 
					@ -331,6 +555,49 @@ def mean(data):
 | 
				
			||||||
    return _convert(total/n, T)
 | 
					    return _convert(total/n, T)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def geometric_mean(data):
 | 
				
			||||||
 | 
					    """Return the geometric mean of data.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    The geometric mean is appropriate when averaging quantities which
 | 
				
			||||||
 | 
					    are multiplied together rather than added, for example growth rates.
 | 
				
			||||||
 | 
					    Suppose an investment grows by 10% in the first year, falls by 5% in
 | 
				
			||||||
 | 
					    the second, then grows by 12% in the third, what is the average rate
 | 
				
			||||||
 | 
					    of growth over the three years?
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    >>> geometric_mean([1.10, 0.95, 1.12])
 | 
				
			||||||
 | 
					    1.0538483123382172
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    giving an average growth of 5.385%. Using the arithmetic mean will
 | 
				
			||||||
 | 
					    give approximately 5.667%, which is too high.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    ``StatisticsError`` will be raised if ``data`` is empty, or any
 | 
				
			||||||
 | 
					    element is less than zero.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    if iter(data) is data:
 | 
				
			||||||
 | 
					        data = list(data)
 | 
				
			||||||
 | 
					    errmsg = 'geometric mean does not support negative values'
 | 
				
			||||||
 | 
					    n = len(data)
 | 
				
			||||||
 | 
					    if n < 1:
 | 
				
			||||||
 | 
					        raise StatisticsError('geometric_mean requires at least one data point')
 | 
				
			||||||
 | 
					    elif n == 1:
 | 
				
			||||||
 | 
					        x = data[0]
 | 
				
			||||||
 | 
					        if isinstance(g, (numbers.Real, Decimal)):
 | 
				
			||||||
 | 
					            if x < 0:
 | 
				
			||||||
 | 
					                raise StatisticsError(errmsg)
 | 
				
			||||||
 | 
					            return x
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            raise TypeError('unsupported type')
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        scale, prod = _product(_fail_neg(data, errmsg))
 | 
				
			||||||
 | 
					        r = _nth_root(prod, n)
 | 
				
			||||||
 | 
					        if scale:
 | 
				
			||||||
 | 
					            p, q = divmod(scale, n)
 | 
				
			||||||
 | 
					            s = 2**p * _nth_root(2**q, n)
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            s = 1
 | 
				
			||||||
 | 
					        return s*r
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def harmonic_mean(data):
 | 
					def harmonic_mean(data):
 | 
				
			||||||
    """Return the harmonic mean of data.
 | 
					    """Return the harmonic mean of data.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1010,6 +1010,291 @@ def test_error_msg(self):
 | 
				
			||||||
        self.assertEqual(errmsg, msg)
 | 
					        self.assertEqual(errmsg, msg)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class Test_Product(NumericTestCase):
 | 
				
			||||||
 | 
					    """Test the private _product function."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_ints(self):
 | 
				
			||||||
 | 
					        data = [1, 2, 5, 7, 9]
 | 
				
			||||||
 | 
					        self.assertEqual(statistics._product(data), (0, 630))
 | 
				
			||||||
 | 
					        self.assertEqual(statistics._product(data*100), (0, 630**100))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_floats(self):
 | 
				
			||||||
 | 
					        data = [1.0, 2.0, 4.0, 8.0]
 | 
				
			||||||
 | 
					        self.assertEqual(statistics._product(data), (8, 0.25))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_overflow(self):
 | 
				
			||||||
 | 
					        # Test with floats that overflow.
 | 
				
			||||||
 | 
					        data = [1e300]*5
 | 
				
			||||||
 | 
					        self.assertEqual(statistics._product(data), (5980, 0.6928287951283193))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_fractions(self):
 | 
				
			||||||
 | 
					        F = Fraction
 | 
				
			||||||
 | 
					        data = [F(14, 23), F(69, 1), F(665, 529), F(299, 105), F(1683, 39)]
 | 
				
			||||||
 | 
					        exp, mant = statistics._product(data)
 | 
				
			||||||
 | 
					        self.assertEqual(exp, 0)
 | 
				
			||||||
 | 
					        self.assertEqual(mant, F(2*3*7*11*17*19, 23))
 | 
				
			||||||
 | 
					        self.assertTrue(isinstance(mant, F))
 | 
				
			||||||
 | 
					        # Mixed Fraction and int.
 | 
				
			||||||
 | 
					        data = [3, 25, F(2, 15)]
 | 
				
			||||||
 | 
					        exp, mant = statistics._product(data)
 | 
				
			||||||
 | 
					        self.assertEqual(exp, 0)
 | 
				
			||||||
 | 
					        self.assertEqual(mant, F(10))
 | 
				
			||||||
 | 
					        self.assertTrue(isinstance(mant, F))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @unittest.expectedFailure
 | 
				
			||||||
 | 
					    def test_decimal(self):
 | 
				
			||||||
 | 
					        D = Decimal
 | 
				
			||||||
 | 
					        data = [D('24.5'), D('17.6'), D('0.025'), D('1.3')]
 | 
				
			||||||
 | 
					        assert False
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_mixed_decimal_float(self):
 | 
				
			||||||
 | 
					        # Test that mixed Decimal and float raises.
 | 
				
			||||||
 | 
					        self.assertRaises(TypeError, statistics._product, [1.0, Decimal(1)])
 | 
				
			||||||
 | 
					        self.assertRaises(TypeError, statistics._product, [Decimal(1), 1.0])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class Test_Nth_Root(NumericTestCase):
 | 
				
			||||||
 | 
					    """Test the functionality of the private _nth_root function."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def setUp(self):
 | 
				
			||||||
 | 
					        self.nroot = statistics._nth_root
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # --- Special values (infinities, NANs, zeroes) ---
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_float_NAN(self):
 | 
				
			||||||
 | 
					        # Test that the root of a float NAN is a float NAN.
 | 
				
			||||||
 | 
					        NAN = float('nan')
 | 
				
			||||||
 | 
					        for n in range(2, 9):
 | 
				
			||||||
 | 
					            with self.subTest(n=n):
 | 
				
			||||||
 | 
					                result = self.nroot(NAN, n)
 | 
				
			||||||
 | 
					                self.assertTrue(math.isnan(result))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_decimal_QNAN(self):
 | 
				
			||||||
 | 
					        # Test the  behaviour when taking the root of a Decimal quiet NAN.
 | 
				
			||||||
 | 
					        NAN = decimal.Decimal('nan')
 | 
				
			||||||
 | 
					        with decimal.localcontext() as ctx:
 | 
				
			||||||
 | 
					            ctx.traps[decimal.InvalidOperation] = 1
 | 
				
			||||||
 | 
					            self.assertRaises(decimal.InvalidOperation, self.nroot, NAN, 5)
 | 
				
			||||||
 | 
					            ctx.traps[decimal.InvalidOperation] = 0
 | 
				
			||||||
 | 
					            self.assertTrue(self.nroot(NAN, 5).is_qnan())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_decimal_SNAN(self):
 | 
				
			||||||
 | 
					        # Test that taking the root of a Decimal sNAN always raises.
 | 
				
			||||||
 | 
					        sNAN = decimal.Decimal('snan')
 | 
				
			||||||
 | 
					        with decimal.localcontext() as ctx:
 | 
				
			||||||
 | 
					            ctx.traps[decimal.InvalidOperation] = 1
 | 
				
			||||||
 | 
					            self.assertRaises(decimal.InvalidOperation, self.nroot, sNAN, 5)
 | 
				
			||||||
 | 
					            ctx.traps[decimal.InvalidOperation] = 0
 | 
				
			||||||
 | 
					            self.assertRaises(decimal.InvalidOperation, self.nroot, sNAN, 5)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_inf(self):
 | 
				
			||||||
 | 
					        # Test that the root of infinity is infinity.
 | 
				
			||||||
 | 
					        for INF in (float('inf'), decimal.Decimal('inf')):
 | 
				
			||||||
 | 
					            for n in range(2, 9):
 | 
				
			||||||
 | 
					                with self.subTest(n=n, inf=INF):
 | 
				
			||||||
 | 
					                    self.assertEqual(self.nroot(INF, n), INF)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def testNInf(self):
 | 
				
			||||||
 | 
					        # Test that the root of -inf is -inf for odd n.
 | 
				
			||||||
 | 
					        for NINF in (float('-inf'), decimal.Decimal('-inf')):
 | 
				
			||||||
 | 
					            for n in range(3, 11, 2):
 | 
				
			||||||
 | 
					                with self.subTest(n=n, inf=NINF):
 | 
				
			||||||
 | 
					                    self.assertEqual(self.nroot(NINF, n), NINF)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # FIXME: need to check Decimal zeroes too.
 | 
				
			||||||
 | 
					    def test_zero(self):
 | 
				
			||||||
 | 
					        # Test that the root of +0.0 is +0.0.
 | 
				
			||||||
 | 
					        for n in range(2, 11):
 | 
				
			||||||
 | 
					            with self.subTest(n=n):
 | 
				
			||||||
 | 
					                result = self.nroot(+0.0, n)
 | 
				
			||||||
 | 
					                self.assertEqual(result, 0.0)
 | 
				
			||||||
 | 
					                self.assertEqual(sign(result), +1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # FIXME: need to check Decimal zeroes too.
 | 
				
			||||||
 | 
					    def test_neg_zero(self):
 | 
				
			||||||
 | 
					        # Test that the root of -0.0 is -0.0.
 | 
				
			||||||
 | 
					        for n in range(2, 11):
 | 
				
			||||||
 | 
					            with self.subTest(n=n):
 | 
				
			||||||
 | 
					                result = self.nroot(-0.0, n)
 | 
				
			||||||
 | 
					                self.assertEqual(result, 0.0)
 | 
				
			||||||
 | 
					                self.assertEqual(sign(result), -1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # --- Test return types ---
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def check_result_type(self, x, n, outtype):
 | 
				
			||||||
 | 
					        self.assertIsInstance(self.nroot(x, n), outtype)
 | 
				
			||||||
 | 
					        class MySubclass(type(x)):
 | 
				
			||||||
 | 
					            pass
 | 
				
			||||||
 | 
					        self.assertIsInstance(self.nroot(MySubclass(x), n), outtype)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def testDecimal(self):
 | 
				
			||||||
 | 
					        # Test that Decimal arguments return Decimal results.
 | 
				
			||||||
 | 
					        self.check_result_type(decimal.Decimal('33.3'), 3, decimal.Decimal)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def testFloat(self):
 | 
				
			||||||
 | 
					        # Test that other arguments return float results.
 | 
				
			||||||
 | 
					        for x in (0.2, Fraction(11, 7), 91):
 | 
				
			||||||
 | 
					            self.check_result_type(x, 6, float)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # --- Test bad input ---
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def testBadOrderTypes(self):
 | 
				
			||||||
 | 
					        # Test that nroot raises correctly when n has the wrong type.
 | 
				
			||||||
 | 
					        for n in (5.0, 2j, None, 'x', b'x', [], {}, set(), sign):
 | 
				
			||||||
 | 
					            with self.subTest(n=n):
 | 
				
			||||||
 | 
					                self.assertRaises(TypeError, self.nroot, 2.5, n)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def testBadOrderValues(self):
 | 
				
			||||||
 | 
					        # Test that nroot raises correctly when n has a wrong value.
 | 
				
			||||||
 | 
					        for n in (1, 0, -1, -2, -87):
 | 
				
			||||||
 | 
					            with self.subTest(n=n):
 | 
				
			||||||
 | 
					                self.assertRaises(ValueError, self.nroot, 2.5, n)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def testBadTypes(self):
 | 
				
			||||||
 | 
					        # Test that nroot raises correctly when x has the wrong type.
 | 
				
			||||||
 | 
					        for x in (None, 'x', b'x', [], {}, set(), sign):
 | 
				
			||||||
 | 
					            with self.subTest(x=x):
 | 
				
			||||||
 | 
					                self.assertRaises(TypeError, self.nroot, x, 3)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def testNegativeEvenPower(self):
 | 
				
			||||||
 | 
					        # Test negative x with even n raises correctly.
 | 
				
			||||||
 | 
					        x = random.uniform(-20.0, -0.1)
 | 
				
			||||||
 | 
					        assert x < 0
 | 
				
			||||||
 | 
					        for n in range(2, 9, 2):
 | 
				
			||||||
 | 
					            with self.subTest(x=x, n=n):
 | 
				
			||||||
 | 
					                self.assertRaises(ValueError, self.nroot, x, n)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # --- Test that nroot is never worse than calling math.pow() ---
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def check_error_is_no_worse(self, x, n):
 | 
				
			||||||
 | 
					        y = math.pow(x, n)
 | 
				
			||||||
 | 
					        with self.subTest(x=x, n=n, y=y):
 | 
				
			||||||
 | 
					            err1 = abs(self.nroot(y, n) - x)
 | 
				
			||||||
 | 
					            err2 = abs(math.pow(y, 1.0/n) - x)
 | 
				
			||||||
 | 
					            self.assertLessEqual(err1, err2)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def testCompareWithPowSmall(self):
 | 
				
			||||||
 | 
					        # Compare nroot with pow for small values of x.
 | 
				
			||||||
 | 
					        for i in range(200):
 | 
				
			||||||
 | 
					            x = random.uniform(1e-9, 1.0-1e-9)
 | 
				
			||||||
 | 
					            n = random.choice(range(2, 16))
 | 
				
			||||||
 | 
					            self.check_error_is_no_worse(x, n)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def testCompareWithPowMedium(self):
 | 
				
			||||||
 | 
					        # Compare nroot with pow for medium-sized values of x.
 | 
				
			||||||
 | 
					        for i in range(200):
 | 
				
			||||||
 | 
					            x = random.uniform(1.0, 100.0)
 | 
				
			||||||
 | 
					            n = random.choice(range(2, 16))
 | 
				
			||||||
 | 
					            self.check_error_is_no_worse(x, n)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def testCompareWithPowLarge(self):
 | 
				
			||||||
 | 
					        # Compare nroot with pow for largish values of x.
 | 
				
			||||||
 | 
					        for i in range(200):
 | 
				
			||||||
 | 
					            x = random.uniform(100.0, 10000.0)
 | 
				
			||||||
 | 
					            n = random.choice(range(2, 16))
 | 
				
			||||||
 | 
					            self.check_error_is_no_worse(x, n)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def testCompareWithPowHuge(self):
 | 
				
			||||||
 | 
					        # Compare nroot with pow for huge values of x.
 | 
				
			||||||
 | 
					        for i in range(200):
 | 
				
			||||||
 | 
					            x = random.uniform(1e20, 1e50)
 | 
				
			||||||
 | 
					            # We restrict the order here to avoid an Overflow error.
 | 
				
			||||||
 | 
					            n = random.choice(range(2, 7))
 | 
				
			||||||
 | 
					            self.check_error_is_no_worse(x, n)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # --- Test for numerically correct answers ---
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def testExactPowers(self):
 | 
				
			||||||
 | 
					        # Test that small integer powers are calculated exactly.
 | 
				
			||||||
 | 
					        for i in range(1, 51):
 | 
				
			||||||
 | 
					            for n in range(2, 16):
 | 
				
			||||||
 | 
					                if (i, n) == (35, 13):
 | 
				
			||||||
 | 
					                    # See testExpectedFailure35p13
 | 
				
			||||||
 | 
					                    continue
 | 
				
			||||||
 | 
					                with self.subTest(i=i, n=n):
 | 
				
			||||||
 | 
					                    x = i**n
 | 
				
			||||||
 | 
					                    self.assertEqual(self.nroot(x, n), i)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def testExactPowersNegatives(self):
 | 
				
			||||||
 | 
					        # Test that small negative integer powers are calculated exactly.
 | 
				
			||||||
 | 
					        for i in range(-1, -51, -1):
 | 
				
			||||||
 | 
					            for n in range(3, 16, 2):
 | 
				
			||||||
 | 
					                if (i, n) == (-35, 13):
 | 
				
			||||||
 | 
					                    # See testExpectedFailure35p13
 | 
				
			||||||
 | 
					                    continue
 | 
				
			||||||
 | 
					                with self.subTest(i=i, n=n):
 | 
				
			||||||
 | 
					                    x = i**n
 | 
				
			||||||
 | 
					                    assert sign(x) == -1
 | 
				
			||||||
 | 
					                    self.assertEqual(self.nroot(x, n), i)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def testExpectedFailure35p13(self):
 | 
				
			||||||
 | 
					        # Test the expected failure 35**13 is almost exact.
 | 
				
			||||||
 | 
					        x = 35**13
 | 
				
			||||||
 | 
					        err = abs(self.nroot(x, 13) - 35)
 | 
				
			||||||
 | 
					        self.assertLessEqual(err, 0.000000001)
 | 
				
			||||||
 | 
					        err = abs(self.nroot(-x, 13) + 35)
 | 
				
			||||||
 | 
					        self.assertLessEqual(err, 0.000000001)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def testOne(self):
 | 
				
			||||||
 | 
					        # Test that the root of 1.0 is 1.0.
 | 
				
			||||||
 | 
					        for n in range(2, 11):
 | 
				
			||||||
 | 
					            with self.subTest(n=n):
 | 
				
			||||||
 | 
					                self.assertEqual(self.nroot(1.0, n), 1.0)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def testFraction(self):
 | 
				
			||||||
 | 
					        # Test Fraction results.
 | 
				
			||||||
 | 
					        x = Fraction(89, 75)
 | 
				
			||||||
 | 
					        self.assertEqual(self.nroot(x**12, 12), float(x))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def testInt(self):
 | 
				
			||||||
 | 
					        # Test int results.
 | 
				
			||||||
 | 
					        x = 276
 | 
				
			||||||
 | 
					        self.assertEqual(self.nroot(x**24, 24), x)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def testBigInt(self):
 | 
				
			||||||
 | 
					        # Test that ints too big to convert to floats work.
 | 
				
			||||||
 | 
					        bignum = 10**20  # That's not that big...
 | 
				
			||||||
 | 
					        self.assertEqual(self.nroot(bignum**280, 280), bignum)
 | 
				
			||||||
 | 
					        # Can we make it bigger?
 | 
				
			||||||
 | 
					        hugenum = bignum**50
 | 
				
			||||||
 | 
					        # Make sure that it is too big to convert to a float.
 | 
				
			||||||
 | 
					        try:
 | 
				
			||||||
 | 
					            y = float(hugenum)
 | 
				
			||||||
 | 
					        except OverflowError:
 | 
				
			||||||
 | 
					            pass
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            raise AssertionError('hugenum is not big enough')
 | 
				
			||||||
 | 
					        self.assertEqual(self.nroot(hugenum, 50), float(bignum))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def testDecimal(self):
 | 
				
			||||||
 | 
					        # Test Decimal results.
 | 
				
			||||||
 | 
					        for s in '3.759 64.027 5234.338'.split():
 | 
				
			||||||
 | 
					            x = decimal.Decimal(s)
 | 
				
			||||||
 | 
					            with self.subTest(x=x):
 | 
				
			||||||
 | 
					                a = self.nroot(x**5, 5)
 | 
				
			||||||
 | 
					                self.assertEqual(a, x)
 | 
				
			||||||
 | 
					                a = self.nroot(x**17, 17)
 | 
				
			||||||
 | 
					                self.assertEqual(a, x)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def testFloat(self):
 | 
				
			||||||
 | 
					        # Test float results.
 | 
				
			||||||
 | 
					        for x in (3.04e-16, 18.25, 461.3, 1.9e17):
 | 
				
			||||||
 | 
					            with self.subTest(x=x):
 | 
				
			||||||
 | 
					                self.assertEqual(self.nroot(x**3, 3), x)
 | 
				
			||||||
 | 
					                self.assertEqual(self.nroot(x**8, 8), x)
 | 
				
			||||||
 | 
					                self.assertEqual(self.nroot(x**11, 11), x)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class Test_NthRoot_NS(unittest.TestCase):
 | 
				
			||||||
 | 
					    """Test internals of the nth_root function, hidden in _nroot_NS."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_class_cannot_be_instantiated(self):
 | 
				
			||||||
 | 
					        # Test that _nroot_NS cannot be instantiated.
 | 
				
			||||||
 | 
					        # It should be a namespace, like in C++ or C#, but Python
 | 
				
			||||||
 | 
					        # lacks that feature and so we have to make do with a class.
 | 
				
			||||||
 | 
					        self.assertRaises(TypeError, statistics._nroot_NS)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# === Tests for public functions ===
 | 
					# === Tests for public functions ===
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class UnivariateCommonMixin:
 | 
					class UnivariateCommonMixin:
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue