summaryrefslogtreecommitdiff
path: root/numpy/core
diff options
context:
space:
mode:
authorTim Hochberg <tim_hochberg@local>2006-10-12 19:19:04 +0000
committerTim Hochberg <tim_hochberg@local>2006-10-12 19:19:04 +0000
commitabb7a32a344bf73dd5f9a878c99335352316480d (patch)
treee3c075733c0f02ac10f342f38b336a2bb1be7dbb /numpy/core
parent61d36f3bcbb50da5d33cf6c614859925e27f6abd (diff)
downloadnumpy-abb7a32a344bf73dd5f9a878c99335352316480d.tar.gz
Added docstring and tests to errstate. Also added 'all' option for seterr so that we can set all the options at once. Note that tests on errstate are only run in Python 2.5 and higher.
Diffstat (limited to 'numpy/core')
-rw-r--r--numpy/core/numeric.py47
-rw-r--r--numpy/core/tests/test_numeric.py7
2 files changed, 48 insertions, 6 deletions
diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py
index 0a211ff61..02b931f90 100644
--- a/numpy/core/numeric.py
+++ b/numpy/core/numeric.py
@@ -617,11 +617,14 @@ for key in _errdict.keys():
_errdict_rev[_errdict[key]] = key
del key
-def seterr(divide=None, over=None, under=None, invalid=None):
+def seterr(all=None, divide=None, over=None, under=None, invalid=None):
"""Set how floating-point errors are handled.
Valid values for each type of error are the strings
"ignore", "warn", "raise", and "call". Returns the old settings.
+ If 'all' is specified, values that are not otherwise specified
+ will be set to 'all', otherwise they will retain their old
+ values.
Note that operations on integer scalar types (such as int16) are
handled like floating point, and are affected by these settings.
@@ -630,19 +633,24 @@ def seterr(divide=None, over=None, under=None, invalid=None):
>>> seterr(over='raise')
{'over': 'ignore', 'divide': 'ignore', 'invalid': 'ignore', 'under': 'ignore'}
+ >>> seterr(all='warn', over='raise')
+ {'over': 'raise', 'divide': 'ignore', 'invalid': 'ignore', 'under': 'ignore'}
>>> int16(32000) * int16(3)
Traceback (most recent call last):
File "<stdin>", line 1, in ?
FloatingPointError: overflow encountered in short_scalars
+ >>> seterr(all='ignore')
+ {'over': 'ignore', 'divide': 'ignore', 'invalid': 'ignore', 'under': 'ignore'}
+
"""
pyvals = umath.geterrobj()
old = geterr()
- if divide is None: divide = old['divide']
- if over is None: over = old['over']
- if under is None: under = old['under']
- if invalid is None: invalid = old['invalid']
+ if divide is None: divide = all or old['divide']
+ if over is None: over = all or old['over']
+ if under is None: under = all or old['under']
+ if invalid is None: invalid = all or old['invalid']
maskvalue = ((_errdict[divide] << SHIFT_DIVIDEBYZERO) +
(_errdict[over] << SHIFT_OVERFLOW ) +
@@ -653,6 +661,7 @@ def seterr(divide=None, over=None, under=None, invalid=None):
umath.seterrobj(pyvals)
return old
+
def geterr():
"""Get the current way of handling floating-point errors.
@@ -718,12 +727,38 @@ def geterrcall():
return umath.geterrobj()[2]
class errstate(object):
+ """with errstate(**state): --> operations in following block use given state.
+
+ # Set error handling to known state.
+ >>> _ = seterr(invalid='raise', divide='raise', over='raise', under='ignore')
+
+ |>> a = -arange(3)
+ |>> with errstate(invalid='ignore'):
+ ... print sqrt(a)
+ [ 0. -1.#IND -1.#IND]
+ |>> print sqrt(a.astype(complex))
+ [ 0. +0.00000000e+00j 0. +1.00000000e+00j 0. +1.41421356e+00j]
+ |>> print sqrt(a)
+ Traceback (most recent call last):
+ ...
+ FloatingPointError: invalid encountered in sqrt
+ |>> with errstate(divide='ignore'):
+ ... print a/0
+ [0 0 0]
+ |>> print a/0
+ Traceback (most recent call last):
+ ...
+ FloatingPointError: divide by zero encountered in divide
+
+ """
+ # Note that we don't want to run the above doctests because they will fail
+ # without a from __future__ import with_statement
def __init__(self, **kwargs):
self.kwargs = kwargs
def __enter__(self):
self.oldstate = seterr(**self.kwargs)
def __exit__(self, *exc_info):
- numpy.seterr(**self.oldstate)
+ seterr(**self.oldstate)
def _setdef():
defval = [UFUNC_BUFSIZE_DEFAULT, ERR_DEFAULT, None]
diff --git a/numpy/core/tests/test_numeric.py b/numpy/core/tests/test_numeric.py
index 44a69f8e8..ea2486972 100644
--- a/numpy/core/tests/test_numeric.py
+++ b/numpy/core/tests/test_numeric.py
@@ -2,6 +2,7 @@ from numpy.core import *
from numpy.random import rand, randint
from numpy.testing import *
from numpy.core.multiarray import dot as dot_
+import sys
class Vec:
def __init__(self,sequence=None):
@@ -246,6 +247,12 @@ class test_binary_repr(NumpyTestCase):
def test_large(self):
assert_equal(binary_repr(10736848),'101000111101010011010000')
+
+import sys
+if sys.version_info[:2] >= (2, 5):
+ set_local_path()
+ from test_errstate import *
+ restore_path()
if __name__ == '__main__':
NumpyTest().run()