summaryrefslogtreecommitdiff
path: root/numpy/core
diff options
context:
space:
mode:
authorsasha <sasha@localhost>2006-03-01 03:41:58 +0000
committersasha <sasha@localhost>2006-03-01 03:41:58 +0000
commitb25ddc20561f810283b7a6ecb49910f45df9770d (patch)
tree2bb1dffd738b69d1cc41533fe7fb659e349e413a /numpy/core
parent8f6eca691ad76cf0264ed9c8e2893e9160efe990 (diff)
downloadnumpy-b25ddc20561f810283b7a6ecb49910f45df9770d.tar.gz
faster ndarray.fill
Diffstat (limited to 'numpy/core')
-rw-r--r--numpy/core/include/numpy/arrayobject.h5
-rw-r--r--numpy/core/src/arrayobject.c14
-rw-r--r--numpy/core/src/arraytypes.inc.src41
-rw-r--r--numpy/core/tests/test_multiarray.py13
4 files changed, 69 insertions, 4 deletions
diff --git a/numpy/core/include/numpy/arrayobject.h b/numpy/core/include/numpy/arrayobject.h
index e02fd3df0..872e4d83a 100644
--- a/numpy/core/include/numpy/arrayobject.h
+++ b/numpy/core/include/numpy/arrayobject.h
@@ -811,6 +811,8 @@ typedef int (PyArray_FillFunc)(void *, intp, void *);
typedef int (PyArray_SortFunc)(void *, intp, void *);
typedef int (PyArray_ArgSortFunc)(void *, intp *, intp, void *);
+typedef int (PyArray_FillWithScalarFunc)(void *, intp, intp, void *);
+
typedef struct {
intp *ptr;
int len;
@@ -853,6 +855,9 @@ typedef struct {
/* Used for arange */
PyArray_FillFunc *fill;
+ /* Function to fill arrays with scalar values */
+ PyArray_FillWithScalarFunc *fillwithscalar;
+
/* Sorting functions */
PyArray_SortFunc *sort[PyArray_NSORTS];
PyArray_ArgSortFunc *argsort[PyArray_NSORTS];
diff --git a/numpy/core/src/arrayobject.c b/numpy/core/src/arrayobject.c
index 55abc36f0..cd83f31e1 100644
--- a/numpy/core/src/arrayobject.c
+++ b/numpy/core/src/arrayobject.c
@@ -4108,9 +4108,17 @@ PyArray_FillWithScalar(PyArrayObject *arr, PyObject *obj)
copyswap = arr->descr->f->copyswap;
if (PyArray_ISONESEGMENT(arr)) {
char *toptr=PyArray_DATA(arr);
- while (size--) {
- copyswap(toptr, fromptr, swap, itemsize);
- toptr += itemsize;
+ PyArray_FillWithScalarFunc* fillwithscalar =
+ arr->descr->f->fillwithscalar;
+ if (fillwithscalar && PyArray_ISALIGNED(arr)) {
+ copyswap(fromptr, NULL, swap, itemsize);
+ fillwithscalar(toptr, size, itemsize, fromptr);
+ }
+ else {
+ while (size--) {
+ copyswap(toptr, fromptr, swap, itemsize);
+ toptr += itemsize;
+ }
}
}
else {
diff --git a/numpy/core/src/arraytypes.inc.src b/numpy/core/src/arraytypes.inc.src
index 0fd105ab1..2c570b946 100644
--- a/numpy/core/src/arraytypes.inc.src
+++ b/numpy/core/src/arraytypes.inc.src
@@ -1713,6 +1713,45 @@ static void
/**end repeat**/
+/* this requires buffer to be filled with objects or NULL */
+static void
+OBJECT_fillwithscalar(PyObject **buffer, intp length, intp ignored, PyObject **value)
+{
+ intp i;
+ PyObject *val = *value;
+ for (i=0; i<length; i++) {
+ Py_XDECREF(buffer[i]);
+ Py_INCREF(val);
+ buffer[i] = val;
+ }
+}
+/**begin repeat
+#NAME=BOOL,BYTE,UBYTE#
+#typ=Bool,byte,ubyte#
+*/
+static void
+@NAME@_fillwithscalar(@typ@ *buffer, intp length, intp ignored, @typ@ *value)
+{
+ memset(buffer, *value, length);
+}
+/**end repeat**/
+
+/**begin repeat
+#NAME=SHORT,USHORT,INT,UINT,LONG,ULONG,LONGLONG,ULONGLONG,FLOAT,DOUBLE,LONGDOUBLE,CFLOAT,CDOUBLE,CLONGDOUBLE#
+#typ=short,ushort,int,uint,long,ulong,longlong,ulonglong,float,double,longdouble,cfloat,cdouble,clongdouble#
+*/
+static void
+@NAME@_fillwithscalar(@typ@ *buffer, intp length, intp ignored, @typ@ *value)
+{
+ register intp i;
+ @typ@ val = *value;
+ for (i=0; i<length; ++i) {
+ buffer[i] = val;
+ }
+}
+
+/**end repeat**/
+
#define _ALIGN(type) offsetof(struct {char c; type v;},v)
/**begin repeat
@@ -1758,6 +1797,7 @@ static PyArray_ArrFuncs _Py@NAME@_ArrFuncs = {
(PyArray_FromStrFunc*)@from@_fromstr,
(PyArray_NonzeroFunc*)@from@_nonzero,
(PyArray_FillFunc*)NULL,
+ (PyArray_FillWithScalarFunc*)NULL,
{
NULL, NULL, NULL, NULL
},
@@ -1828,6 +1868,7 @@ static PyArray_ArrFuncs _Py@NAME@_ArrFuncs = {
(PyArray_FromStrFunc*)@from@_fromstr,
(PyArray_NonzeroFunc*)@from@_nonzero,
(PyArray_FillFunc*)@from@_fill,
+ (PyArray_FillWithScalarFunc*)@from@_fillwithscalar,
{
NULL, NULL, NULL, NULL
},
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py
index 25bd98a45..87093467b 100644
--- a/numpy/core/tests/test_multiarray.py
+++ b/numpy/core/tests/test_multiarray.py
@@ -90,7 +90,18 @@ class test_attributes(ScipyTestCase):
self.failUnlessRaises(ValueError, make_array, 4, 2, -1)
self.failUnlessRaises(ValueError, make_array, 8, 3, 1)
#self.failUnlessRaises(ValueError, make_array, 8, 3, 0)
-
+
+ def check_fill(self):
+ for t in "?bhilqpBHILQPfdgFDGO":
+ x = empty((3,2,1), t)
+ y = empty((3,2,1), t)
+ x.fill(1)
+ y[...] = 1
+ assert_equal(x,y)
+
+ x = array([(0,0.0), (1,1.0)], dtype='i4,f8')
+ x.fill(x[0])
+ assert_equal(x['f1'][1], x['f1'][0])
class test_dtypedescr(ScipyTestCase):
def check_construction(self):