diff options
author | Oleg Shyshkov <shyshkov@google.com> | 2023-05-15 17:04:03 +0200 |
---|---|---|
committer | Oleg Shyshkov <shyshkov@google.com> | 2023-05-15 17:09:04 +0200 |
commit | b4d6aada623d2cbf7c34713e216679d4f013ac9d (patch) | |
tree | 5c8e5f5078627def63a6a234c3046d4d9dbdb3ce /mlir | |
parent | 36d4e4c9b5f6cd0577b6029055b825caaec2dd11 (diff) | |
download | llvm-b4d6aada623d2cbf7c34713e216679d4f013ac9d.tar.gz |
[mlir][memref] Extract isStaticShapeAndContiguousRowMajor as a util function.
Differential Revision: https://reviews.llvm.org/D150543
Diffstat (limited to 'mlir')
-rw-r--r-- | mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h | 13 | ||||
-rw-r--r-- | mlir/lib/Conversion/MemRefToLLVM/CMakeLists.txt | 1 | ||||
-rw-r--r-- | mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp | 31 | ||||
-rw-r--r-- | mlir/lib/Dialect/MemRef/CMakeLists.txt | 1 | ||||
-rw-r--r-- | mlir/lib/Dialect/MemRef/Utils/CMakeLists.txt | 9 | ||||
-rw-r--r-- | mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp | 48 |
6 files changed, 74 insertions, 29 deletions
diff --git a/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h b/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h index f7e525a7374f..6854c183b449 100644 --- a/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h +++ b/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h @@ -16,4 +16,17 @@ #ifndef MLIR_DIALECT_MEMREF_UTILS_MEMREFUTILS_H #define MLIR_DIALECT_MEMREF_UTILS_MEMREFUTILS_H +namespace mlir { + +class MemRefType; + +namespace memref { + +/// Returns true, if the memref type has static shapes and represents a +/// contiguous chunk of memory. +bool isStaticShapeAndContiguousRowMajor(MemRefType type); + +} // namespace memref +} // namespace mlir + #endif // MLIR_DIALECT_MEMREF_UTILS_MEMREFUTILS_H diff --git a/mlir/lib/Conversion/MemRefToLLVM/CMakeLists.txt b/mlir/lib/Conversion/MemRefToLLVM/CMakeLists.txt index 0fa997bcc25d..1400618c93e8 100644 --- a/mlir/lib/Conversion/MemRefToLLVM/CMakeLists.txt +++ b/mlir/lib/Conversion/MemRefToLLVM/CMakeLists.txt @@ -16,6 +16,7 @@ add_mlir_conversion_library(MLIRMemRefToLLVM MLIRDataLayoutInterfaces MLIRLLVMCommonConversion MLIRMemRefDialect + MLIRMemRefUtils MLIRLLVMDialect MLIRTransforms ) diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp index 013baef3dc07..17b9b7404768 100644 --- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp @@ -19,6 +19,7 @@ #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/IRMapping.h" #include "mlir/Pass/Pass.h" @@ -1055,34 +1056,6 @@ struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> { auto srcType = cast<BaseMemRefType>(op.getSource().getType()); auto targetType = cast<BaseMemRefType>(op.getTarget().getType()); - auto isStaticShapeAndContiguousRowMajor = [](MemRefType type) { - if (!type.hasStaticShape()) - return false; - - SmallVector<int64_t> strides; - int64_t offset; - if (failed(getStridesAndOffset(type, strides, offset))) - return false; - - // MemRef is contiguous if outer dimensions are size-1 and inner - // dimensions have unit strides. - int64_t runningStride = 1; - int64_t curDim = strides.size() - 1; - // Finds all inner dimensions with unit strides. - while (curDim >= 0 && strides[curDim] == runningStride) { - runningStride *= type.getDimSize(curDim); - --curDim; - } - - // Check if other dimensions are size-1. - while (curDim >= 0 && type.getDimSize(curDim) == 1) { - --curDim; - } - - // All dims are unit-strided or size-1. - return curDim < 0; - }; - auto isContiguousMemrefType = [&](BaseMemRefType type) { auto memrefType = dyn_cast<mlir::MemRefType>(type); // We can use memcpy for memrefs if they have an identity layout or are @@ -1091,7 +1064,7 @@ struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> { return memrefType && (memrefType.getLayout().isIdentity() || (memrefType.hasStaticShape() && memrefType.getNumElements() > 0 && - isStaticShapeAndContiguousRowMajor(memrefType))); + memref::isStaticShapeAndContiguousRowMajor(memrefType))); }; if (isContiguousMemrefType(srcType) && isContiguousMemrefType(targetType)) diff --git a/mlir/lib/Dialect/MemRef/CMakeLists.txt b/mlir/lib/Dialect/MemRef/CMakeLists.txt index 660deb21479d..c47e4c5495c1 100644 --- a/mlir/lib/Dialect/MemRef/CMakeLists.txt +++ b/mlir/lib/Dialect/MemRef/CMakeLists.txt @@ -1,3 +1,4 @@ add_subdirectory(IR) add_subdirectory(TransformOps) add_subdirectory(Transforms) +add_subdirectory(Utils) diff --git a/mlir/lib/Dialect/MemRef/Utils/CMakeLists.txt b/mlir/lib/Dialect/MemRef/Utils/CMakeLists.txt new file mode 100644 index 000000000000..0af6ba2b7fd4 --- /dev/null +++ b/mlir/lib/Dialect/MemRef/Utils/CMakeLists.txt @@ -0,0 +1,9 @@ +add_mlir_dialect_library(MLIRMemRefUtils + MemRefUtils.cpp + + ADDITIONAL_HEADER_DIRS + ${PROJECT_SOURCE_DIR}/inlude/mlir/Dialect/MemRef/Utils + + LINK_LIBS PUBLIC + MLIRIR +) diff --git a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp new file mode 100644 index 000000000000..5e42602bc3ea --- /dev/null +++ b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp @@ -0,0 +1,48 @@ +//===- MemRefUtils.cpp - Utilities to support the MemRef dialect ----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements utilities for the MemRef dialect. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" + +namespace mlir { +namespace memref { + +bool isStaticShapeAndContiguousRowMajor(MemRefType type) { + if (!type.hasStaticShape()) + return false; + + SmallVector<int64_t> strides; + int64_t offset; + if (failed(getStridesAndOffset(type, strides, offset))) + return false; + + // MemRef is contiguous if outer dimensions are size-1 and inner + // dimensions have unit strides. + int64_t runningStride = 1; + int64_t curDim = strides.size() - 1; + // Finds all inner dimensions with unit strides. + while (curDim >= 0 && strides[curDim] == runningStride) { + runningStride *= type.getDimSize(curDim); + --curDim; + } + + // Check if other dimensions are size-1. + while (curDim >= 0 && type.getDimSize(curDim) == 1) { + --curDim; + } + + // All dims are unit-strided or size-1. + return curDim < 0; +}; + +} // namespace memref +} // namespace mlir |