summaryrefslogtreecommitdiff
path: root/mlir/lib
diff options
context:
space:
mode:
authorOleg Shyshkov <shyshkov@google.com>2023-05-12 17:14:50 +0200
committerOleg Shyshkov <shyshkov@google.com>2023-05-12 17:18:13 +0200
commitcad08503b8d5ffc3834ab2f3e10f9cf44f6f0ee3 (patch)
treebb0050fc1cb76718c27c3b4667942ddd367b8132 /mlir/lib
parentd0718ff410cc0911766b86cd162d19b0b780b0ea (diff)
downloadllvm-cad08503b8d5ffc3834ab2f3e10f9cf44f6f0ee3.tar.gz
[mlir][memref] Lower copy of memrefs with outer size-1 dims to intrinsic memcpy.
With this change, more `memref.copy` will be lowered to the efficient `memcpy`. For example, ``` memref.copy %subview, %alloc : memref<1x576xf32, strided<[704, 1]>> to memref<1x576xf32> ``` Differential Revision: https://reviews.llvm.org/D150448
Diffstat (limited to 'mlir/lib')
-rw-r--r--mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp20
1 files changed, 15 insertions, 5 deletions
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 1a6e5a4e8dbd..013baef3dc07 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -1064,13 +1064,23 @@ struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
if (failed(getStridesAndOffset(type, strides, offset)))
return false;
+ // MemRef is contiguous if outer dimensions are size-1 and inner
+ // dimensions have unit strides.
int64_t runningStride = 1;
- for (unsigned i = strides.size(); i > 0; --i) {
- if (strides[i - 1] != runningStride)
- return false;
- runningStride *= type.getDimSize(i - 1);
+ int64_t curDim = strides.size() - 1;
+ // Finds all inner dimensions with unit strides.
+ while (curDim >= 0 && strides[curDim] == runningStride) {
+ runningStride *= type.getDimSize(curDim);
+ --curDim;
}
- return true;
+
+ // Check if other dimensions are size-1.
+ while (curDim >= 0 && type.getDimSize(curDim) == 1) {
+ --curDim;
+ }
+
+ // All dims are unit-strided or size-1.
+ return curDim < 0;
};
auto isContiguousMemrefType = [&](BaseMemRefType type) {