summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorEric Fode <ericfode@gmail.com>2012-07-12 15:35:44 -0400
committerEric Fode <ericfode@gmail.com>2012-07-12 16:00:12 -0400
commita77a7cd5ed6bc12772ba33c151f40fbe73e9d212 (patch)
tree82cd226e60b7b73caecf7645498c12debdd16a5f
parent143fb1874b3ff3875a93accad3e87056a44d77d0 (diff)
downloadnumpy-a77a7cd5ed6bc12772ba33c151f40fbe73e9d212.tar.gz
First attempt at BF for 2028 and added better tests for scalarmath pow function
-rw-r--r--numpy/core/src/scalarmathmodule.c.src19
-rw-r--r--numpy/core/tests/test_scalarmath.py15
2 files changed, 26 insertions, 8 deletions
diff --git a/numpy/core/src/scalarmathmodule.c.src b/numpy/core/src/scalarmathmodule.c.src
index d9f7abc4e..8b66f7135 100644
--- a/numpy/core/src/scalarmathmodule.c.src
+++ b/numpy/core/src/scalarmathmodule.c.src
@@ -494,16 +494,25 @@ half_ctype_remainder(npy_half a, npy_half b, npy_half *out) {
/**end repeat**/
/**begin repeat
- * #name = half, float, double, longdouble#
- * #type = npy_half, npy_float, npy_double, npy_longdouble#
+ * #name = float, double, longdouble#
+ * #type = npy_float, npy_double, npy_longdouble#
*/
static npy_@name@ (*_basic_@name@_pow)(@type@ a, @type@ b);
+//called when ** is used (not performing properly)
static void
@name@_ctype_power(@type@ a, @type@ b, @type@ *out) {
- *out = _basic_@name@_pow(a, b);
+ *out = _basic_@name@_pow(a, b);
}
/**end repeat**/
+static void
+half_ctype_power(npy_half a,npy_half b, npy_half *out)
+{
+ const npy_float af = npy_half_to_float(a);
+ const npy_float bf = npy_half_to_float(b);
+ const npy_float of = _basic_float_pow(af,bf);
+ *out = npy_float_to_half(of);
+}
/**begin repeat
* #name = byte, ubyte, short, ushort, int, uint,
@@ -970,7 +979,7 @@ static PyObject *
int retstatus;
int first;
@type@ out = {@zero@, @zero@};
-
+
switch(_@name@_convert2_to_ctypes(a, &arg1, b, &arg2)) {
case 0:
break;
@@ -1130,7 +1139,6 @@ static PyObject *
int first;
@type@ out = @zero@;
-
switch(_@name@_convert2_to_ctypes(a, &arg1, b, &arg2)) {
case 0:
break;
@@ -1724,7 +1732,6 @@ get_functions(void)
i += 3;
j++;
}
- _basic_half_pow = funcdata[j - 1];
_basic_float_pow = funcdata[j];
_basic_double_pow = funcdata[j + 1];
_basic_longdouble_pow = funcdata[j + 2];
diff --git a/numpy/core/tests/test_scalarmath.py b/numpy/core/tests/test_scalarmath.py
index 24b5eae24..3f25d008f 100644
--- a/numpy/core/tests/test_scalarmath.py
+++ b/numpy/core/tests/test_scalarmath.py
@@ -44,7 +44,7 @@ class TestTypes(TestCase):
class TestPower(TestCase):
def test_small_types(self):
- for t in [np.int8, np.int16]:
+ for t in [np.int8, np.int16, np.float16]:
a = t(3)
b = a ** 4
assert_(b == 81, "error with %r: got %r" % (t,b))
@@ -58,7 +58,18 @@ class TestPower(TestCase):
assert_(b == 6765201, msg)
else:
assert_almost_equal(b, 6765201, err_msg=msg)
-
+ def test_mixed_types(self):
+ typelist = [np.int8,np.int16,np.float16,np.float32,np.float64,np.int8,np.int16,np.int32,np.int64]
+ for t1 in typelist:
+ for t2 in typelist:
+ a = t1(3)
+ b = t2(2)
+ o = a**b
+ msg = "error with %r and %r: got %r, expected %r" % (t1,t2,o,9)
+ if np.issubdtype(np.dtype(o),np.integer):
+ assert_(o == 9,msg)
+ else:
+ assert_almost_equal(o,9,err_msg=msg)
class TestComplexDivision(TestCase):
def test_zero_division(self):