fast take patch

classic Classic list List threaded Threaded
2 messages Options
Reply | Threaded
Open this post in threaded view
|

fast take patch

Eric Firing
Stefan (or anyone else who might be interested),

Since you committed my fast putmask patch many months ago, I thought you
might like to deal with my fast take patch.  Attached is the diff
relative to 5043, ignoring whitespace.  (Yes, those pesky whitespace
anomalies are still cropping up.)

The motivation for the patch is that in matplotlib color mapping, such
as when displaying an image or pcolor plot, the time-limiting factor can
be a single numpy operation of indexing into the color map.
(http://www.mail-archive.com/matplotlib-users@.../msg06482.html)
It was being done with fancy indexing.  I changed that to use the take
method, which sped it up by a factor of two, but I found that take
itself is slower than it needs to be by a factor of 2-3.  As with
putmask, the culprit is memmove, so I used the same strategy to
substitute direct variable assignment when possible.

The patch includes tests for the take method.  I did not find any such
pre-existing tests.

Thanks for taking a look.  Let me know what needs to be changed.

Eric

Index: numpy/core/include/numpy/ndarrayobject.h
===================================================================
--- numpy/core/include/numpy/ndarrayobject.h (revision 5043)
+++ numpy/core/include/numpy/ndarrayobject.h (working copy)
@@ -1052,6 +1052,10 @@
                                     void *max, void *out);
 typedef void (PyArray_FastPutmaskFunc)(void *in, void *mask, npy_intp n_in,
                                        void *values, npy_intp nv);
+typedef int  (PyArray_FastTakeFunc)(void *dest, void *src, npy_intp *indarray,
+                                       npy_intp nindarray, npy_intp n_outer,
+                                       npy_intp m_middle, npy_intp nelem,
+                                       NPY_CLIPMODE clipmode);
 
 typedef struct {
         npy_intp *ptr;
@@ -1130,6 +1134,7 @@
 
         PyArray_FastClipFunc *fastclip;
         PyArray_FastPutmaskFunc *fastputmask;
+        PyArray_FastTakeFunc *fasttake;
 } PyArray_ArrFuncs;
 
 #define NPY_ITEM_REFCOUNT   0x01  /* The item must be reference counted
Index: numpy/core/src/multiarraymodule.c
===================================================================
--- numpy/core/src/multiarraymodule.c (revision 5043)
+++ numpy/core/src/multiarraymodule.c (working copy)
@@ -3824,11 +3824,13 @@
 PyArray_TakeFrom(PyArrayObject *self0, PyObject *indices0, int axis,
                  PyArrayObject *ret, NPY_CLIPMODE clipmode)
 {
+    PyArray_FastTakeFunc *func;
     PyArrayObject *self, *indices;
-    intp nd, i, j, n, m, max_item, tmp, chunk;
+    intp nd, i, j, n, m, max_item, tmp, chunk, nelem;
     intp shape[MAX_DIMS];
     char *src, *dest;
     int copyret=0;
+    int err;
 
     indices = NULL;
     self = (PyAO *)_check_axis(self0, &axis, CARRAY);
@@ -3892,10 +3894,14 @@
     }
 
     max_item = self->dimensions[axis];
+    nelem = chunk;
     chunk = chunk * ret->descr->elsize;
     src = self->data;
     dest = ret->data;
 
+    func = self->descr->f->fasttake;
+    if (func == NULL) {
+
     switch(clipmode) {
     case NPY_RAISE:
         for(i=0; i<n; i++) {
@@ -3943,6 +3949,12 @@
         }
         break;
     }
+    }
+    else {
+        err = func(dest, src, (intp *)(indices->data),
+                    max_item, n, m, nelem, clipmode);
+        if (err) goto fail;
+    }
 
     PyArray_INCREF(ret);
 
Index: numpy/core/src/arraytypes.inc.src
===================================================================
--- numpy/core/src/arraytypes.inc.src (revision 5043)
+++ numpy/core/src/arraytypes.inc.src (working copy)
@@ -2179,6 +2179,89 @@
 #define OBJECT_fastputmask NULL
 
 
+
+/************************
+ * Fast take functions
+ *************************/
+
+/**begin repeat
+#name=BOOL,BYTE, UBYTE, SHORT, USHORT, INT, UINT, LONG, ULONG, LONGLONG, ULONGLONG, FLOAT, DOUBLE, LONGDOUBLE,CFLOAT, CDOUBLE, CLONGDOUBLE#
+#type= Bool, byte, ubyte, short, ushort, int, uint, long, ulong, longlong, ulonglong, float, double, longdouble,cfloat, cdouble, clongdouble#
+*/
+static int
+@name@_fasttake(@type@ *dest, @type@ *src, intp *indarray,
+                    intp nindarray, intp n_outer,
+                    intp m_middle, intp nelem,
+                    NPY_CLIPMODE clipmode)
+{
+    intp i, j, k, tmp;
+
+    switch(clipmode) {
+    case NPY_RAISE:
+        for(i=0; i<n_outer; i++) {
+            for(j=0; j<m_middle; j++) {
+                tmp = indarray[j];
+                if (tmp < 0) tmp = tmp+nindarray;
+                if ((tmp < 0) || (tmp >= nindarray)) {
+                    PyErr_SetString(PyExc_IndexError,
+                                    "index out of range "\
+                                    "for array");
+                    return 1;
+                }
+                if (nelem == 1) *dest++ = *(src+tmp);
+                else {
+                    for(k=0; k<nelem; k++) {
+                        *dest++ = *(src+tmp*nelem+k);
+                    }
+                }
+            }
+            src += nelem*nindarray;
+        }
+        break;
+    case NPY_WRAP:
+        for(i=0; i<n_outer; i++) {
+            for(j=0; j<m_middle; j++) {
+                tmp = indarray[j];
+                if (tmp < 0) while (tmp < 0) tmp += nindarray;
+                else if (tmp >= nindarray)
+                    while (tmp >= nindarray)
+                        tmp -= nindarray;
+                if (nelem == 1) *dest++ = *(src+tmp);
+                else {
+                    for(k=0; k<nelem; k++) {
+                        *dest++ = *(src+tmp*nelem+k);
+                    }
+                }
+            }
+            src += nelem*nindarray;
+        }
+        break;
+    case NPY_CLIP:
+        for(i=0; i<n_outer; i++) {
+            for(j=0; j<m_middle; j++) {
+                tmp = indarray[j];
+                if (tmp < 0)
+                    tmp = 0;
+                else if (tmp >= nindarray)
+                    tmp = nindarray-1;
+                if (nelem == 1) *dest++ = *(src+tmp);
+                else {
+                    for(k=0; k<nelem; k++) {
+                        *dest++ = *(src+tmp*nelem+k);
+                    }
+                }
+            }
+            src += nelem*nindarray;
+        }
+        break;
+    }
+    return 0;
+}
+/**end repeat**/
+
+#define OBJECT_fasttake NULL
+
+
 #define _ALIGN(type) offsetof(struct {char c; type v;},v)
 
 /* Disable harmless compiler warning "4116: unnamed type definition in
@@ -2244,7 +2327,8 @@
     NULL,
     NULL,
     (PyArray_FastClipFunc *)NULL,
-    (PyArray_FastPutmaskFunc *)NULL
+    (PyArray_FastPutmaskFunc *)NULL,
+    (PyArray_FastTakeFunc *)NULL
 };
 
 static PyArray_Descr @from@_Descr = {
@@ -2322,7 +2406,8 @@
     NULL,
     NULL,
     (PyArray_FastClipFunc*)@from@_fastclip,
-    (PyArray_FastPutmaskFunc*)@from@_fastputmask
+    (PyArray_FastPutmaskFunc*)@from@_fastputmask,
+    (PyArray_FastTakeFunc*)@from@_fasttake
 };
 
 static PyArray_Descr @from@_Descr = {
Index: numpy/core/tests/test_multiarray.py
===================================================================
--- numpy/core/tests/test_multiarray.py (revision 5043)
+++ numpy/core/tests/test_multiarray.py (working copy)
@@ -699,6 +699,56 @@
         ## np.putmask(z,[True,True,True],3)
         pass
 
+class TestTake(ParametricTestCase):
+    def tst_basic(self,x):
+        ind = range(x.shape[0])
+        assert_array_equal(x.take(ind, axis=0), x)
+
+    def testip_types(self):
+        unchecked_types = [str, unicode, np.void, object]
+
+        x = np.random.random(24)*100
+        x.shape = 2,3,4
+        tests = []
+        for types in np.sctypes.itervalues():
+            tests.extend([(self.tst_basic,x.copy().astype(T))
+                          for T in types if T not in unchecked_types])
+        return tests
+
+    def test_raise(self):
+        x = np.random.random(24)*100
+        x.shape = 2,3,4
+        self.failUnlessRaises(IndexError, x.take, [0,1,2], axis=0)
+        self.failUnlessRaises(IndexError, x.take, [-3], axis=0)
+        assert_array_equal(x.take([-1], axis=0)[0], x[1])
+
+    def test_clip(self):
+        x = np.random.random(24)*100
+        x.shape = 2,3,4
+        assert_array_equal(x.take([-1], axis=0, mode='clip')[0], x[0])
+        assert_array_equal(x.take([2], axis=0, mode='clip')[0], x[1])
+
+    def test_wrap(self):
+        x = np.random.random(24)*100
+        x.shape = 2,3,4
+        assert_array_equal(x.take([-1], axis=0, mode='wrap')[0], x[1])
+        assert_array_equal(x.take([2], axis=0, mode='wrap')[0], x[0])
+        assert_array_equal(x.take([3], axis=0, mode='wrap')[0], x[1])
+
+    def tst_byteorder(self,dtype):
+        x = np.array([1,2,3],dtype)
+        assert_array_equal(x.take([0,2,1]),[1,3,2])
+
+    def testip_byteorder(self):
+        return [(self.tst_byteorder,dtype) for dtype in ('>i4','<i4')]
+
+    def test_record_array(self):
+        # Note mixed byteorder.
+        rec = np.array([(-5, 2.0, 3.0), (5.0, 4.0, 3.0)],
+                      dtype=[('x', '<f8'), ('y', '>f8'), ('z', '<f8')])
+        rec1 = rec.take([1])
+        assert rec1['x'] == 5.0 and rec1['y'] == 4.0
+
 class TestLexsort(NumpyTestCase):
     def test_basic(self):
         a = [1,2,1,3,1,5]

_______________________________________________
Numpy-discussion mailing list
[hidden email]
http://projects.scipy.org/mailman/listinfo/numpy-discussion
Reply | Threaded
Open this post in threaded view
|

Re: fast take patch

Stéfan van der Walt
Hi Eric

On 18/04/2008, Eric Firing <[hidden email]> wrote:
>  The motivation for the patch is that in matplotlib color mapping, such as

[...]

Beautiful patch, good motivation.  Thank you!

Applied in r5044.

Regards
Stéfan
_______________________________________________
Numpy-discussion mailing list
[hidden email]
http://projects.scipy.org/mailman/listinfo/numpy-discussion