diff options
author | Mark Florisson <markflorisson88@gmail.com> | 2012-05-05 17:15:48 +0100 |
---|---|---|
committer | Mark Florisson <markflorisson88@gmail.com> | 2012-05-11 11:48:57 +0100 |
commit | 75c761cf13bf67fb58c124a47fea5f5d8e817834 (patch) | |
tree | 5c86e2621c7e20aac14c6edf4d3a5aafc7700d2a /Cython/Compiler/MemoryView.py | |
parent | 0edf8d2a38468f317e7350b4fc0ed3d359652831 (diff) | |
download | cython-75c761cf13bf67fb58c124a47fea5f5d8e817834.tar.gz |
Support newaxis indexing for memoryview slices
todo: support memoryview object newaxis indexing
Diffstat (limited to 'Cython/Compiler/MemoryView.py')
-rw-r--r-- | Cython/Compiler/MemoryView.py | 69 |
1 files changed, 42 insertions, 27 deletions
diff --git a/Cython/Compiler/MemoryView.py b/Cython/Compiler/MemoryView.py index 24ab79e4b..135a1f39f 100644 --- a/Cython/Compiler/MemoryView.py +++ b/Cython/Compiler/MemoryView.py @@ -277,8 +277,8 @@ class MemoryViewSliceBufferEntry(Buffer.BufferEntry): """ Slice a memoryviewslice. - indices - list of index nodes. If not a SliceNode, then it must be - coercible to Py_ssize_t + indices - list of index nodes. If not a SliceNode, or NoneNode, + then it must be coercible to Py_ssize_t Simply call __pyx_memoryview_slice_memviewslice with the right arguments. @@ -307,28 +307,14 @@ class MemoryViewSliceBufferEntry(Buffer.BufferEntry): code.putln("%(dst)s.memview = %(src)s.memview;" % locals()) code.put_incref_memoryviewslice(dst) - for dim, index in enumerate(indices): + dim = -1 + for index in indices: error_goto = code.error_goto(index.pos) - - if not isinstance(index, ExprNodes.SliceNode): - # normal index - idx = index.result() - + if not index.is_none: + dim += 1 access, packing = self.type.axes[dim] - if access == 'direct': - indirect = False - else: - indirect = True - generic = (access == 'full') - if new_ndim != 0: - return error(index.pos, - "All preceding dimensions must be " - "indexed and not sliced") - - d = locals() - code.put(load_slice_util("SliceIndex", d)) - else: + if isinstance(index, ExprNodes.SliceNode): # slice, unspecified dimension, or part of ellipsis d = locals() for s in "start stop step".split(): @@ -344,7 +330,6 @@ class MemoryViewSliceBufferEntry(Buffer.BufferEntry): not d['have_step']): # full slice (:), simply copy over the extent, stride # and suboffset. Also update suboffset_dim if needed - access, packing = self.type.axes[dim] d['access'] = access code.put(load_slice_util("SimpleSlice", d)) else: @@ -352,6 +337,31 @@ class MemoryViewSliceBufferEntry(Buffer.BufferEntry): new_ndim += 1 + elif index.is_none: + # newaxis + attribs = [('shape', 1), ('strides', 0), ('suboffsets', -1)] + for attrib, value in attribs: + code.putln("%s.%s[%d] = %d;" % (dst, attrib, new_ndim, value)) + + new_ndim += 1 + + else: + # normal index + idx = index.result() + + if access == 'direct': + indirect = False + else: + indirect = True + generic = (access == 'full') + if new_ndim != 0: + return error(index.pos, + "All preceding dimensions must be " + "indexed and not sliced") + + d = locals() + code.put(load_slice_util("SliceIndex", d)) + if not no_suboffset_dim: code.funcstate.release_temp(suboffset_dim) @@ -361,11 +371,13 @@ def empty_slice(pos): return ExprNodes.SliceNode(pos, start=none, stop=none, step=none) -def unellipsify(indices, ndim): +def unellipsify(indices, newaxes, ndim): result = [] seen_ellipsis = False have_slices = False + n_indices = len(indices) - len(newaxes) + for index in indices: if isinstance(index, ExprNodes.EllipsisNode): have_slices = True @@ -374,16 +386,19 @@ def unellipsify(indices, ndim): if seen_ellipsis: result.append(full_slice) else: - nslices = ndim - len(indices) + 1 + nslices = ndim - n_indices + 1 result.extend([full_slice] * nslices) seen_ellipsis = True else: - have_slices = have_slices or isinstance(index, ExprNodes.SliceNode) + have_slices = (have_slices or + isinstance(index, ExprNodes.SliceNode) or + index.is_none) result.append(index) - if len(result) < ndim: + result_length = len(result) - len(newaxes) + if result_length < ndim: have_slices = True - nslices = ndim - len(result) + nslices = ndim - result_length result.extend([empty_slice(indices[-1].pos)] * nslices) return have_slices, result |