summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNicolas Vasilache <ntv@google.com>2022-01-09 11:51:58 -0500
committerNicolas Vasilache <ntv@google.com>2022-01-09 14:13:08 -0500
commit9ba25ec92d88639561797674296b81fb3b67eed5 (patch)
tree67d352ff4437d943aac801774228870c20711c99
parent1ce01b7dfe8247c25b25e0ed44b7f1e41599bb43 (diff)
downloadllvm-9ba25ec92d88639561797674296b81fb3b67eed5.tar.gz
[mlir][Bufferize] NFC - Introduce areCastCompatible assertions to catch misformed CastOp early
Differential Revision: https://reviews.llvm.org/D116893
-rw-r--r--mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp3
-rw-r--r--mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp6
-rw-r--r--mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp14
-rw-r--r--mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp3
4 files changed, 24 insertions, 2 deletions
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
index d2d726312d6a..e64d5ae3dda6 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
@@ -549,6 +549,9 @@ mlir::linalg::comprehensive_bufferize::BufferizationState::createAlloc(
return failure();
Value casted = allocated.getValue();
if (memRefType && memRefType != allocMemRefType) {
+ assert(memref::CastOp::areCastCompatible(allocated.getValue().getType(),
+ memRefType) &&
+ "createAlloc: cast incompatible");
casted = b.create<memref::CastOp>(loc, memRefType, allocated.getValue());
}
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp
index 17719244c5c3..fd3632fb56d0 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp
@@ -77,9 +77,13 @@ struct ToMemrefOpInterface
// Insert cast in case to_memref(to_tensor(x))'s type is different from
// x's type.
- if (toTensorOp.memref().getType() != toMemrefOp.getType())
+ if (toTensorOp.memref().getType() != toMemrefOp.getType()) {
+ assert(memref::CastOp::areCastCompatible(buffer.getType(),
+ toMemrefOp.getType()) &&
+ "ToMemrefOp::bufferize : cast incompatible");
buffer = rewriter.create<memref::CastOp>(toMemrefOp.getLoc(), buffer,
toMemrefOp.getType());
+ }
replaceOpWithBufferizedValues(rewriter, toMemrefOp, buffer);
return success();
}
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
index 8c6b32d73317..8138ab2952ea 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
@@ -386,7 +386,10 @@ static LogicalResult bufferizeFuncOpBoundary(FuncOp funcOp,
// Replace all uses of bbArg through a ToMemRefOp by a memref::CastOp.
for (auto &use : llvm::make_early_inc_range(bbArg.getUses())) {
if (auto toMemrefOp =
- dyn_cast<bufferization::ToMemrefOp>(use.getOwner())) {
+ dyn_cast<bufferization::ToMemrefOp>(use.getOwner())) {
+ assert(memref::CastOp::areCastCompatible(
+ memref.getType(), toMemrefOp.memref().getType()) &&
+ "bufferizeFuncOpBoundary: cast incompatible");
auto castOp = b.create<memref::CastOp>(
funcOp.getLoc(), toMemrefOp.memref().getType(), memref);
toMemrefOp.memref().replaceAllUsesWith(castOp);
@@ -525,6 +528,8 @@ static void layoutPostProcessing(ModuleOp moduleOp) {
bbArg.setType(desiredMemrefType);
OpBuilder b(bbArg.getContext());
b.setInsertionPointToStart(bbArg.getOwner());
+ assert(memref::CastOp::areCastCompatible(bbArg.getType(), memrefType) &&
+ "layoutPostProcessing: cast incompatible");
// Cast back to the original memrefType and let it canonicalize.
Value cast =
b.create<memref::CastOp>(funcOp.getLoc(), memrefType, bbArg);
@@ -537,6 +542,10 @@ static void layoutPostProcessing(ModuleOp moduleOp) {
// such cases.
auto castArg = [&](Operation *caller) {
OpBuilder b(caller);
+ assert(
+ memref::CastOp::areCastCompatible(
+ caller->getOperand(argNumber).getType(), desiredMemrefType) &&
+ "layoutPostProcessing.2: cast incompatible");
Value newOperand = b.create<memref::CastOp>(
funcOp.getLoc(), desiredMemrefType, caller->getOperand(argNumber));
operandsPerCaller.find(caller)->getSecond().push_back(newOperand);
@@ -703,6 +712,9 @@ struct CallOpInterface
// that will either canonicalize away or fail compilation until we can do
// something better.
if (buffer.getType() != memRefType) {
+ assert(
+ memref::CastOp::areCastCompatible(buffer.getType(), memRefType) &&
+ "CallOp::bufferize: cast incompatible");
Value castBuffer = rewriter.create<memref::CastOp>(callOp.getLoc(),
memRefType, buffer);
buffer = castBuffer;
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
index 7c9114b284b2..f0f20b433937 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
@@ -77,6 +77,9 @@ struct CastOpInterface
}
// Replace the op with a memref.cast.
+ assert(memref::CastOp::areCastCompatible(resultBuffer->getType(),
+ resultMemRefType) &&
+ "CallOp::bufferize: cast incompatible");
replaceOpWithNewBufferizedOp<memref::CastOp>(rewriter, op, resultMemRefType,
*resultBuffer);