diff options
author | Nicolas Vasilache <ntv@google.com> | 2022-01-09 11:51:58 -0500 |
---|---|---|
committer | Nicolas Vasilache <ntv@google.com> | 2022-01-09 14:13:08 -0500 |
commit | 9ba25ec92d88639561797674296b81fb3b67eed5 (patch) | |
tree | 67d352ff4437d943aac801774228870c20711c99 | |
parent | 1ce01b7dfe8247c25b25e0ed44b7f1e41599bb43 (diff) | |
download | llvm-9ba25ec92d88639561797674296b81fb3b67eed5.tar.gz |
[mlir][Bufferize] NFC - Introduce areCastCompatible assertions to catch misformed CastOp early
Differential Revision: https://reviews.llvm.org/D116893
4 files changed, 24 insertions, 2 deletions
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp index d2d726312d6a..e64d5ae3dda6 100644 --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp @@ -549,6 +549,9 @@ mlir::linalg::comprehensive_bufferize::BufferizationState::createAlloc( return failure(); Value casted = allocated.getValue(); if (memRefType && memRefType != allocMemRefType) { + assert(memref::CastOp::areCastCompatible(allocated.getValue().getType(), + memRefType) && + "createAlloc: cast incompatible"); casted = b.create<memref::CastOp>(loc, memRefType, allocated.getValue()); } diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp index 17719244c5c3..fd3632fb56d0 100644 --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp @@ -77,9 +77,13 @@ struct ToMemrefOpInterface // Insert cast in case to_memref(to_tensor(x))'s type is different from // x's type. - if (toTensorOp.memref().getType() != toMemrefOp.getType()) + if (toTensorOp.memref().getType() != toMemrefOp.getType()) { + assert(memref::CastOp::areCastCompatible(buffer.getType(), + toMemrefOp.getType()) && + "ToMemrefOp::bufferize : cast incompatible"); buffer = rewriter.create<memref::CastOp>(toMemrefOp.getLoc(), buffer, toMemrefOp.getType()); + } replaceOpWithBufferizedValues(rewriter, toMemrefOp, buffer); return success(); } diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp index 8c6b32d73317..8138ab2952ea 100644 --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp @@ -386,7 +386,10 @@ static LogicalResult bufferizeFuncOpBoundary(FuncOp funcOp, // Replace all uses of bbArg through a ToMemRefOp by a memref::CastOp. for (auto &use : llvm::make_early_inc_range(bbArg.getUses())) { if (auto toMemrefOp = - dyn_cast<bufferization::ToMemrefOp>(use.getOwner())) { + dyn_cast<bufferization::ToMemrefOp>(use.getOwner())) { + assert(memref::CastOp::areCastCompatible( + memref.getType(), toMemrefOp.memref().getType()) && + "bufferizeFuncOpBoundary: cast incompatible"); auto castOp = b.create<memref::CastOp>( funcOp.getLoc(), toMemrefOp.memref().getType(), memref); toMemrefOp.memref().replaceAllUsesWith(castOp); @@ -525,6 +528,8 @@ static void layoutPostProcessing(ModuleOp moduleOp) { bbArg.setType(desiredMemrefType); OpBuilder b(bbArg.getContext()); b.setInsertionPointToStart(bbArg.getOwner()); + assert(memref::CastOp::areCastCompatible(bbArg.getType(), memrefType) && + "layoutPostProcessing: cast incompatible"); // Cast back to the original memrefType and let it canonicalize. Value cast = b.create<memref::CastOp>(funcOp.getLoc(), memrefType, bbArg); @@ -537,6 +542,10 @@ static void layoutPostProcessing(ModuleOp moduleOp) { // such cases. auto castArg = [&](Operation *caller) { OpBuilder b(caller); + assert( + memref::CastOp::areCastCompatible( + caller->getOperand(argNumber).getType(), desiredMemrefType) && + "layoutPostProcessing.2: cast incompatible"); Value newOperand = b.create<memref::CastOp>( funcOp.getLoc(), desiredMemrefType, caller->getOperand(argNumber)); operandsPerCaller.find(caller)->getSecond().push_back(newOperand); @@ -703,6 +712,9 @@ struct CallOpInterface // that will either canonicalize away or fail compilation until we can do // something better. if (buffer.getType() != memRefType) { + assert( + memref::CastOp::areCastCompatible(buffer.getType(), memRefType) && + "CallOp::bufferize: cast incompatible"); Value castBuffer = rewriter.create<memref::CastOp>(callOp.getLoc(), memRefType, buffer); buffer = castBuffer; diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp index 7c9114b284b2..f0f20b433937 100644 --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp @@ -77,6 +77,9 @@ struct CastOpInterface } // Replace the op with a memref.cast. + assert(memref::CastOp::areCastCompatible(resultBuffer->getType(), + resultMemRefType) && + "CallOp::bufferize: cast incompatible"); replaceOpWithNewBufferizedOp<memref::CastOp>(rewriter, op, resultMemRefType, *resultBuffer); |