diff options
Diffstat (limited to 'Lib/statistics.py')
-rw-r--r-- | Lib/statistics.py | 418 |
1 files changed, 359 insertions, 59 deletions
diff --git a/Lib/statistics.py b/Lib/statistics.py index 4f5c1c164a..7d53e0c0e2 100644 --- a/Lib/statistics.py +++ b/Lib/statistics.py @@ -1,20 +1,3 @@ -## Module statistics.py -## -## Copyright (c) 2013 Steven D'Aprano <steve+python@pearwood.info>. -## -## Licensed under the Apache License, Version 2.0 (the "License"); -## you may not use this file except in compliance with the License. -## You may obtain a copy of the License at -## -## http://www.apache.org/licenses/LICENSE-2.0 -## -## Unless required by applicable law or agreed to in writing, software -## distributed under the License is distributed on an "AS IS" BASIS, -## WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -## See the License for the specific language governing permissions and -## limitations under the License. - - """ Basic statistics module. @@ -28,6 +11,8 @@ Calculating averages Function Description ================== ============================================= mean Arithmetic mean (average) of data. +geometric_mean Geometric mean of data. +harmonic_mean Harmonic mean of data. median Median (middle value) of data. median_low Low median of data. median_high High median of data. @@ -95,16 +80,18 @@ A single exception is defined: StatisticsError is a subclass of ValueError. __all__ = [ 'StatisticsError', 'pstdev', 'pvariance', 'stdev', 'variance', 'median', 'median_low', 'median_high', 'median_grouped', - 'mean', 'mode', + 'mean', 'mode', 'geometric_mean', 'harmonic_mean', ] - import collections +import decimal import math +import numbers from fractions import Fraction from decimal import Decimal -from itertools import groupby +from itertools import groupby, chain +from bisect import bisect_left, bisect_right @@ -134,7 +121,8 @@ def _sum(data, start=0): Some sources of round-off error will be avoided: - >>> _sum([1e50, 1, -1e50] * 1000) # Built-in sum returns zero. + # Built-in sum returns zero. 
+ >>> _sum([1e50, 1, -1e50] * 1000) (<class 'float'>, Fraction(1000, 1), 3000) Fractions and Decimals are also supported: @@ -223,56 +211,26 @@ def _exact_ratio(x): # Optimise the common case of floats. We expect that the most often # used numeric type will be builtin floats, so try to make this as # fast as possible. - if type(x) is float: + if type(x) is float or type(x) is Decimal: return x.as_integer_ratio() try: # x may be an int, Fraction, or Integral ABC. return (x.numerator, x.denominator) except AttributeError: try: - # x may be a float subclass. + # x may be a float or Decimal subclass. return x.as_integer_ratio() except AttributeError: - try: - # x may be a Decimal. - return _decimal_to_ratio(x) - except AttributeError: - # Just give up? - pass + # Just give up? + pass except (OverflowError, ValueError): # float NAN or INF. - assert not math.isfinite(x) + assert not _isfinite(x) return (x, None) msg = "can't convert type '{}' to numerator/denominator" raise TypeError(msg.format(type(x).__name__)) -# FIXME This is faster than Fraction.from_decimal, but still too slow. -def _decimal_to_ratio(d): - """Convert Decimal d to exact integer ratio (numerator, denominator). 
-
-    >>> from decimal import Decimal
-    >>> _decimal_to_ratio(Decimal("2.6"))
-    (26, 10)
-
-    """
-    sign, digits, exp = d.as_tuple()
-    if exp in ('F', 'n', 'N'): # INF, NAN, sNAN
-        assert not d.is_finite()
-        return (d, None)
-    num = 0
-    for digit in digits:
-        num = num*10 + digit
-    if exp < 0:
-        den = 10**-exp
-    else:
-        num *= 10**exp
-        den = 1
-    if sign:
-        num = -num
-    return (num, den)
-
-
 def _convert(value, T):
     """Convert value to given numeric type T."""
     if type(value) is T:
@@ -305,6 +263,253 @@ def _counts(data):
     return table
 
 
+def _find_lteq(a, x):
+    'Locate the leftmost value exactly equal to x'
+    i = bisect_left(a, x)
+    if i != len(a) and a[i] == x:
+        return i
+    raise ValueError
+
+
+def _find_rteq(a, l, x):
+    'Locate the rightmost value exactly equal to x'
+    i = bisect_right(a, x, lo=l)
+    if i != 0 and a[i-1] == x:
+        return i-1
+    raise ValueError
+
+
+def _fail_neg(values, errmsg='negative value'):
+    """Iterate over values, failing if any are less than zero."""
+    for x in values:
+        if x < 0:
+            raise StatisticsError(errmsg)
+        yield x
+
+
+class _nroot_NS:
+    """Hands off! Don't touch!
+
+    Everything inside this namespace (class) is an even-more-private
+    implementation detail of the private _nth_root function.
+    """
+    # This class exists only to be used as a namespace, for convenience
+    # of being able to keep the related functions together, and to
+    # collapse the group in an editor. If this were C# or C++, I would
+    # use a Namespace, but the closest Python has is a class.
+    #
+    # FIXME possibly move this out into a separate module?
+    # That feels like overkill, and may encourage people to treat it as
+    # a public feature.
+    def __init__(self):
+        raise TypeError('namespace only, do not instantiate')
+
+    def nth_root(x, n):
+        """Return the positive nth root of numeric x.
+
+        This may be more accurate than ** or pow():
+
+        >>> math.pow(1000, 1.0/3) #doctest:+SKIP
+        9.999999999999998
+
+        >>> _nth_root(1000, 3)
+        10.0
+        >>> _nth_root(11**5, 5)
+        11.0
+        >>> _nth_root(2, 12)
+        1.0594630943592953
+
+        """
+        if not isinstance(n, int):
+            raise TypeError('degree n must be an int')
+        if n < 2:
+            raise ValueError('degree n must be 2 or more')
+        if isinstance(x, decimal.Decimal):
+            return _nroot_NS.decimal_nroot(x, n)
+        elif isinstance(x, numbers.Real):
+            return _nroot_NS.float_nroot(x, n)
+        else:
+            raise TypeError('expected a number, got %s' % type(x).__name__)
+
+    def float_nroot(x, n):
+        """Handle nth root of Reals, treated as a float."""
+        assert isinstance(n, int) and n > 1
+        if x < 0:
+            raise ValueError('domain error: root of negative number')
+        elif x == 0:
+            return math.copysign(0.0, x)
+        elif x > 0:
+            try:
+                isinfinity = math.isinf(x)
+            except OverflowError:
+                return _nroot_NS.bignum_nroot(x, n)
+            else:
+                if isinfinity:
+                    return float('inf')
+                else:
+                    return _nroot_NS.nroot(x, n)
+        else:
+            assert math.isnan(x)
+            return float('nan')
+
+    def nroot(x, n):
+        """Calculate x**(1/n), then improve the answer."""
+        # This uses math.pow() to calculate an initial guess for the root,
+        # then uses the iterated nroot algorithm to improve it.
+        #
+        # By my testing, about 8% of the time the iterated algorithm ends
+        # up converging to a result which is less accurate than the initial
+        # guess. [FIXME: is this still true?] In that case, we use the
+        # guess instead of the "improved" value. This way, we're never
+        # less accurate than math.pow().
+        r1 = math.pow(x, 1.0/n)
+        eps1 = abs(r1**n - x)
+        if eps1 == 0.0:
+            # r1 is the exact root, so we're done. By my testing, this
+            # occurs about 80% of the time for x < 1 and 30% of the
+            # time for x > 1.
+ return r1 + else: + try: + r2 = _nroot_NS.iterated_nroot(x, n, r1) + except RuntimeError: + return r1 + else: + eps2 = abs(r2**n - x) + if eps1 < eps2: + return r1 + return r2 + + def iterated_nroot(a, n, g): + """Return the nth root of a, starting with guess g. + + This is a special case of Newton's Method. + https://en.wikipedia.org/wiki/Nth_root_algorithm + """ + np = n - 1 + def iterate(r): + try: + return (np*r + a/math.pow(r, np))/n + except OverflowError: + # If r is large enough, r**np may overflow. If that + # happens, r**-np will be small, but not necessarily zero. + return (np*r + a*math.pow(r, -np))/n + # With a good guess, such as g = a**(1/n), this will converge in + # only a few iterations. However a poor guess can take thousands + # of iterations to converge, if at all. We guard against poor + # guesses by setting an upper limit to the number of iterations. + r1 = g + r2 = iterate(g) + for i in range(1000): + if r1 == r2: + break + # Use Floyd's cycle-finding algorithm to avoid being trapped + # in a cycle. + # https://en.wikipedia.org/wiki/Cycle_detection#Tortoise_and_hare + r1 = iterate(r1) + r2 = iterate(iterate(r2)) + else: + # If the guess is particularly bad, the above may fail to + # converge in any reasonable time. + raise RuntimeError('nth-root failed to converge') + return r2 + + def decimal_nroot(x, n): + """Handle nth root of Decimals.""" + assert isinstance(x, decimal.Decimal) + assert isinstance(n, int) + if x.is_snan(): + # Signalling NANs always raise. + raise decimal.InvalidOperation('nth-root of snan') + if x.is_qnan(): + # Quiet NANs only raise if the context is set to raise, + # otherwise return a NAN. + ctx = decimal.getcontext() + if ctx.traps[decimal.InvalidOperation]: + raise decimal.InvalidOperation('nth-root of nan') + else: + # Preserve the input NAN. 
+ return x + if x < 0: + raise ValueError('domain error: root of negative number') + if x.is_infinite(): + return x + # FIXME this hasn't had the extensive testing of the float + # version _iterated_nroot so there's possibly some buggy + # corner cases buried in here. Can it overflow? Fail to + # converge or get trapped in a cycle? Converge to a less + # accurate root? + np = n - 1 + def iterate(r): + return (np*r + x/r**np)/n + r0 = x**(decimal.Decimal(1)/n) + assert isinstance(r0, decimal.Decimal) + r1 = iterate(r0) + while True: + if r1 == r0: + return r1 + r0, r1 = r1, iterate(r1) + + def bignum_nroot(x, n): + """Return the nth root of a positive huge number.""" + assert x > 0 + # I state without proof that ⁿ√x ≈ ⁿ√2·ⁿ√(x//2) + # and that for sufficiently big x the error is acceptable. + # We now halve x until it is small enough to get the root. + m = 0 + while True: + x //= 2 + m += 1 + try: + y = float(x) + except OverflowError: + continue + break + a = _nroot_NS.nroot(y, n) + # At this point, we want the nth-root of 2**m, or 2**(m/n). + # We can write that as 2**(q + r/n) = 2**q * ⁿ√2**r where q = m//n. + q, r = divmod(m, n) + b = 2**q * _nroot_NS.nroot(2**r, n) + return a * b + + +# This is the (private) function for calculating nth roots: +_nth_root = _nroot_NS.nth_root +assert type(_nth_root) is type(lambda: None) + + +def _product(values): + """Return product of values as (exponent, mantissa).""" + errmsg = 'mixed Decimal and float is not supported' + prod = 1 + for x in values: + if isinstance(x, float): + break + prod *= x + else: + return (0, prod) + if isinstance(prod, Decimal): + raise TypeError(errmsg) + # Since floats can overflow easily, we calculate the product as a + # sort of poor-man's BigFloat. 
Given that:
+    #
+    #     x = 2**p * m # p == power or exponent (scale), m = mantissa
+    #
+    # we can calculate the product of two (or more) x values as:
+    #
+    #     x1*x2 = 2**p1*m1 * 2**p2*m2 = 2**(p1+p2)*(m1*m2)
+    #
+    mant, scale = 1, 0 # prod (the pre-float factors) is folded in below
+    for y in chain([prod, x], values):
+        if isinstance(y, Decimal):
+            raise TypeError(errmsg)
+        m1, e1 = math.frexp(y)
+        m2, e2 = math.frexp(mant)
+        scale += (e1 + e2)
+        mant = m1*m2
+    return (scale, mant)
+
+
 # === Measures of central tendency (averages) ===
 
 def mean(data):
@@ -333,6 +538,95 @@ def mean(data):
     return _convert(total/n, T)
 
 
+def geometric_mean(data):
+    """Return the geometric mean of data.
+
+    The geometric mean is appropriate when averaging quantities which
+    are multiplied together rather than added, for example growth rates.
+    Suppose an investment grows by 10% in the first year, falls by 5% in
+    the second, then grows by 12% in the third, what is the average rate
+    of growth over the three years?
+
+    >>> geometric_mean([1.10, 0.95, 1.12])
+    1.0538483123382172
+
+    giving an average growth of 5.385%. Using the arithmetic mean will
+    give approximately 5.667%, which is too high.
+
+    ``StatisticsError`` will be raised if ``data`` is empty, or any
+    element is less than zero.
+    """
+    if iter(data) is data:
+        data = list(data)
+    errmsg = 'geometric mean does not support negative values'
+    n = len(data)
+    if n < 1:
+        raise StatisticsError('geometric_mean requires at least one data point')
+    elif n == 1:
+        x = data[0]
+        if isinstance(x, (numbers.Real, Decimal)):
+            if x < 0:
+                raise StatisticsError(errmsg)
+            return x
+        else:
+            raise TypeError('unsupported type')
+    else:
+        scale, prod = _product(_fail_neg(data, errmsg))
+        r = _nth_root(prod, n)
+        if scale:
+            p, q = divmod(scale, n)
+            s = 2**p * _nth_root(2**q, n)
+        else:
+            s = 1
+        return s*r
+
+
+def harmonic_mean(data):
+    """Return the harmonic mean of data.
+ + The harmonic mean, sometimes called the subcontrary mean, is the + reciprocal of the arithmetic mean of the reciprocals of the data, + and is often appropriate when averaging quantities which are rates + or ratios, for example speeds. Example: + + Suppose an investor purchases an equal value of shares in each of + three companies, with P/E (price/earning) ratios of 2.5, 3 and 10. + What is the average P/E ratio for the investor's portfolio? + + >>> harmonic_mean([2.5, 3, 10]) # For an equal investment portfolio. + 3.6 + + Using the arithmetic mean would give an average of about 5.167, which + is too high. + + If ``data`` is empty, or any element is less than zero, + ``harmonic_mean`` will raise ``StatisticsError``. + """ + # For a justification for using harmonic mean for P/E ratios, see + # http://fixthepitch.pellucid.com/comps-analysis-the-missing-harmony-of-summary-statistics/ + # http://papers.ssrn.com/sol3/papers.cfm?abstract_id=2621087 + if iter(data) is data: + data = list(data) + errmsg = 'harmonic mean does not support negative values' + n = len(data) + if n < 1: + raise StatisticsError('harmonic_mean requires at least one data point') + elif n == 1: + x = data[0] + if isinstance(x, (numbers.Real, Decimal)): + if x < 0: + raise StatisticsError(errmsg) + return x + else: + raise TypeError('unsupported type') + try: + T, total, count = _sum(1/x for x in _fail_neg(data, errmsg)) + except ZeroDivisionError: + return 0 + assert count == n + return _convert(n/total, T) + + # FIXME: investigate ways to calculate medians without sorting? Quickselect? def median(data): """Return the median (middle value) of numeric data. @@ -442,9 +736,15 @@ def median_grouped(data, interval=1): except TypeError: # Mixed type. For now we just coerce to float. L = float(x) - float(interval)/2 - cf = data.index(x) # Number of values below the median interval. - # FIXME The following line could be more efficient for big lists. 
- f = data.count(x) # Number of data points in the median interval. + + # Uses bisection search to search for x in data with log(n) time complexity + # Find the position of leftmost occurrence of x in data + l1 = _find_lteq(data, x) + # Find the position of rightmost occurrence of x in data[l1...len(data)] + # Assuming always l1 <= l2 + l2 = _find_rteq(data, l1, x) + cf = l1 + f = l2 - l1 + 1 return L + interval*(n/2 - cf)/f |