diff options
author | Oleg Shyshkov <shyshkov@google.com> | 2023-05-12 17:14:50 +0200 |
---|---|---|
committer | Oleg Shyshkov <shyshkov@google.com> | 2023-05-12 17:18:13 +0200 |
commit | cad08503b8d5ffc3834ab2f3e10f9cf44f6f0ee3 (patch) | |
tree | bb0050fc1cb76718c27c3b4667942ddd367b8132 /mlir/lib | |
parent | d0718ff410cc0911766b86cd162d19b0b780b0ea (diff) | |
download | llvm-cad08503b8d5ffc3834ab2f3e10f9cf44f6f0ee3.tar.gz |
[mlir][memref] Lower copy of memrefs with outer size-1 dims to intrinsic memcpy.
With this change, more `memref.copy` will be lowered to the efficient `memcpy`. For example,
```
memref.copy %subview, %alloc : memref<1x576xf32, strided<[704, 1]>> to memref<1x576xf32>
```
Differential Revision: https://reviews.llvm.org/D150448
Diffstat (limited to 'mlir/lib')
-rw-r--r-- | mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp | 20 |
1 files changed, 15 insertions, 5 deletions
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp index 1a6e5a4e8dbd..013baef3dc07 100644 --- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp @@ -1064,13 +1064,23 @@ struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> { if (failed(getStridesAndOffset(type, strides, offset))) return false; + // MemRef is contiguous if outer dimensions are size-1 and inner + // dimensions have unit strides. int64_t runningStride = 1; - for (unsigned i = strides.size(); i > 0; --i) { - if (strides[i - 1] != runningStride) - return false; - runningStride *= type.getDimSize(i - 1); + int64_t curDim = strides.size() - 1; + // Finds all inner dimensions with unit strides. + while (curDim >= 0 && strides[curDim] == runningStride) { + runningStride *= type.getDimSize(curDim); + --curDim; } - return true; + + // Check if other dimensions are size-1. + while (curDim >= 0 && type.getDimSize(curDim) == 1) { + --curDim; + } + + // All dims are unit-strided or size-1. + return curDim < 0; }; auto isContiguousMemrefType = [&](BaseMemRefType type) { |