summaryrefslogtreecommitdiff
path: root/mlir/lib/Dialect/MemRef
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect/MemRef')
-rw-r--r--mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp2
-rw-r--r--mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp8
-rw-r--r--mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp2
-rw-r--r--mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp8
-rw-r--r--mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp38
-rw-r--r--mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp4
-rw-r--r--mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp2
-rw-r--r--mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp8
-rw-r--r--mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp10
-rw-r--r--mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp18
-rw-r--r--mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp8
-rw-r--r--mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp8
12 files changed, 55 insertions, 61 deletions
diff --git a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
index 7f702e197854..ae2472db4f86 100644
--- a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
+++ b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
@@ -69,7 +69,7 @@ DiagnosedSilenceableFailure transform::MemRefMultiBufferOp::apply(
results.push_back(*newBuffer);
}
- transformResults.set(getResult().cast<OpResult>(), results);
+ transformResults.set(cast<OpResult>(getResult()), results);
return DiagnosedSilenceableFailure::success();
}
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp b/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp
index 369f22521895..9b1d85b29027 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp
@@ -57,7 +57,7 @@ struct ComposeSubViewOpPattern : public OpRewritePattern<memref::SubViewOp> {
// always 1.
if (llvm::all_of(strides, [](OpFoldResult &valueOrAttr) {
Attribute attr = valueOrAttr.dyn_cast<Attribute>();
- return attr && attr.cast<IntegerAttr>().getInt() == 1;
+ return attr && cast<IntegerAttr>(attr).getInt() == 1;
})) {
strides = SmallVector<OpFoldResult>(sourceOp.getMixedStrides().size(),
rewriter.getI64IntegerAttr(1));
@@ -93,8 +93,8 @@ struct ComposeSubViewOpPattern : public OpRewritePattern<memref::SubViewOp> {
// If both offsets are static we can simply calculate the combined
// offset statically.
offsets.push_back(rewriter.getI64IntegerAttr(
- opOffsetAttr.cast<IntegerAttr>().getInt() +
- sourceOffsetAttr.cast<IntegerAttr>().getInt()));
+ cast<IntegerAttr>(opOffsetAttr).getInt() +
+ cast<IntegerAttr>(sourceOffsetAttr).getInt()));
} else {
// When either offset is dynamic, we must emit an additional affine
// transformation to add the two offsets together dynamically.
@@ -102,7 +102,7 @@ struct ComposeSubViewOpPattern : public OpRewritePattern<memref::SubViewOp> {
SmallVector<Value> affineApplyOperands;
for (auto valueOrAttr : {opOffset, sourceOffset}) {
if (auto attr = valueOrAttr.dyn_cast<Attribute>()) {
- expr = expr + attr.cast<IntegerAttr>().getInt();
+ expr = expr + cast<IntegerAttr>(attr).getInt();
} else {
expr =
expr + rewriter.getAffineSymbolExpr(affineApplyOperands.size());
diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp
index 6202b5730c21..57f0141c95dc 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp
@@ -149,7 +149,7 @@ void memref::populateMemRefWideIntEmulationConversions(
arith::WideIntEmulationConverter &typeConverter) {
typeConverter.addConversion(
[&typeConverter](MemRefType ty) -> std::optional<Type> {
- auto intTy = ty.getElementType().dyn_cast<IntegerType>();
+ auto intTy = dyn_cast<IntegerType>(ty.getElementType());
if (!intTy)
return ty;
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp
index 38fb11348f28..8a276ebbff6a 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp
@@ -89,11 +89,11 @@ public:
LogicalResult matchAndRewrite(memref::ReshapeOp op,
PatternRewriter &rewriter) const final {
- auto shapeType = op.getShape().getType().cast<MemRefType>();
+ auto shapeType = cast<MemRefType>(op.getShape().getType());
if (!shapeType.hasStaticShape())
return failure();
- int64_t rank = shapeType.cast<MemRefType>().getDimSize(0);
+ int64_t rank = cast<MemRefType>(shapeType).getDimSize(0);
SmallVector<OpFoldResult, 4> sizes, strides;
sizes.resize(rank);
strides.resize(rank);
@@ -106,7 +106,7 @@ public:
if (op.getType().isDynamicDim(i)) {
Value index = rewriter.create<arith::ConstantIndexOp>(loc, i);
size = rewriter.create<memref::LoadOp>(loc, op.getShape(), index);
- if (!size.getType().isa<IndexType>())
+ if (!isa<IndexType>(size.getType()))
size = rewriter.create<arith::IndexCastOp>(
loc, rewriter.getIndexType(), size);
sizes[i] = size;
@@ -141,7 +141,7 @@ struct ExpandOpsPass : public memref::impl::ExpandOpsBase<ExpandOpsPass> {
op.getKind() != arith::AtomicRMWKind::minf;
});
target.addDynamicallyLegalOp<memref::ReshapeOp>([](memref::ReshapeOp op) {
- return !op.getShape().getType().cast<MemRefType>().hasStaticShape();
+ return !cast<MemRefType>(op.getShape().getType()).hasStaticShape();
});
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
index ea372bffbc0b..ff2c4107ee46 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
@@ -62,7 +62,7 @@ resolveSubviewStridedMetadata(RewriterBase &rewriter,
// Build a plain extract_strided_metadata(memref) from subview(memref).
Location origLoc = subview.getLoc();
Value source = subview.getSource();
- auto sourceType = source.getType().cast<MemRefType>();
+ auto sourceType = cast<MemRefType>(source.getType());
unsigned sourceRank = sourceType.getRank();
auto newExtractStridedMetadata =
@@ -115,7 +115,7 @@ resolveSubviewStridedMetadata(RewriterBase &rewriter,
// The final result is <baseBuffer, offset, sizes, strides>.
// Thus we need 1 + 1 + subview.getRank() + subview.getRank(), to hold all
// the values.
- auto subType = subview.getType().cast<MemRefType>();
+ auto subType = cast<MemRefType>(subview.getType());
unsigned subRank = subType.getRank();
// The sizes of the final type are defined directly by the input sizes of
@@ -338,7 +338,7 @@ SmallVector<OpFoldResult> getExpandedStrides(memref::ExpandShapeOp expandShape,
// Collect the statically known information about the original stride.
Value source = expandShape.getSrc();
- auto sourceType = source.getType().cast<MemRefType>();
+ auto sourceType = cast<MemRefType>(source.getType());
auto [strides, offset] = getStridesAndOffset(sourceType);
OpFoldResult origStride = ShapedType::isDynamic(strides[groupId])
@@ -358,10 +358,9 @@ SmallVector<OpFoldResult> getExpandedStrides(memref::ExpandShapeOp expandShape,
AffineExpr s0 = builder.getAffineSymbolExpr(0);
AffineExpr s1 = builder.getAffineSymbolExpr(1);
for (; doneStrideIdx < *dynSizeIdx; ++doneStrideIdx) {
- int64_t baseExpandedStride = expandedStrides[doneStrideIdx]
- .get<Attribute>()
- .cast<IntegerAttr>()
- .getInt();
+ int64_t baseExpandedStride =
+ cast<IntegerAttr>(expandedStrides[doneStrideIdx].get<Attribute>())
+ .getInt();
expandedStrides[doneStrideIdx] = makeComposedFoldedAffineApply(
builder, expandShape.getLoc(),
(s0 * baseExpandedStride).floorDiv(productOfAllStaticSizes) * s1,
@@ -372,10 +371,9 @@ SmallVector<OpFoldResult> getExpandedStrides(memref::ExpandShapeOp expandShape,
// Now apply the origStride to the remaining dimensions.
AffineExpr s0 = builder.getAffineSymbolExpr(0);
for (; doneStrideIdx < groupSize; ++doneStrideIdx) {
- int64_t baseExpandedStride = expandedStrides[doneStrideIdx]
- .get<Attribute>()
- .cast<IntegerAttr>()
- .getInt();
+ int64_t baseExpandedStride =
+ cast<IntegerAttr>(expandedStrides[doneStrideIdx].get<Attribute>())
+ .getInt();
expandedStrides[doneStrideIdx] = makeComposedFoldedAffineApply(
builder, expandShape.getLoc(), s0 * baseExpandedStride, {origStride});
}
@@ -445,7 +443,7 @@ getCollapsedSize(memref::CollapseShapeOp collapseShape, OpBuilder &builder,
// Build the affine expr of the product of the original sizes involved in that
// group.
Value source = collapseShape.getSrc();
- auto sourceType = source.getType().cast<MemRefType>();
+ auto sourceType = cast<MemRefType>(source.getType());
SmallVector<int64_t, 2> reassocGroup =
collapseShape.getReassociationIndices()[groupId];
@@ -479,7 +477,7 @@ getCollapsedStride(memref::CollapseShapeOp collapseShape, OpBuilder &builder,
"Reassociation group should have at least one dimension");
Value source = collapseShape.getSrc();
- auto sourceType = source.getType().cast<MemRefType>();
+ auto sourceType = cast<MemRefType>(source.getType());
auto [strides, offset] = getStridesAndOffset(sourceType);
@@ -562,7 +560,7 @@ public:
// extract_strided_metadata(reassociative_reshape_like(memref)).
Location origLoc = reshape.getLoc();
Value source = reshape.getSrc();
- auto sourceType = source.getType().cast<MemRefType>();
+ auto sourceType = cast<MemRefType>(source.getType());
unsigned sourceRank = sourceType.getRank();
auto newExtractStridedMetadata =
@@ -650,8 +648,7 @@ public:
if (!allocLikeOp)
return failure();
- auto memRefType =
- allocLikeOp.getResult().getType().template cast<MemRefType>();
+ auto memRefType = cast<MemRefType>(allocLikeOp.getResult().getType());
if (!memRefType.getLayout().isIdentity())
return rewriter.notifyMatchFailure(
allocLikeOp, "alloc-like operations should have been normalized");
@@ -688,7 +685,7 @@ public:
SmallVector<Value> results;
results.reserve(rank * 2 + 2);
- auto baseBufferType = op.getBaseBuffer().getType().cast<MemRefType>();
+ auto baseBufferType = cast<MemRefType>(op.getBaseBuffer().getType());
int64_t offset = 0;
if (allocLikeOp.getType() == baseBufferType)
results.push_back(allocLikeOp);
@@ -737,7 +734,7 @@ public:
if (!getGlobalOp)
return failure();
- auto memRefType = getGlobalOp.getResult().getType().cast<MemRefType>();
+ auto memRefType = cast<MemRefType>(getGlobalOp.getResult().getType());
if (!memRefType.getLayout().isIdentity()) {
return rewriter.notifyMatchFailure(
getGlobalOp,
@@ -759,7 +756,7 @@ public:
SmallVector<Value> results;
results.reserve(rank * 2 + 2);
- auto baseBufferType = op.getBaseBuffer().getType().cast<MemRefType>();
+ auto baseBufferType = cast<MemRefType>(op.getBaseBuffer().getType());
int64_t offset = 0;
if (getGlobalOp.getType() == baseBufferType)
results.push_back(getGlobalOp);
@@ -838,8 +835,7 @@ class ExtractStridedMetadataOpReinterpretCastFolder
return rewriter.notifyMatchFailure(
reinterpretCastOp, "reinterpret_cast source's type is incompatible");
- auto memrefType =
- reinterpretCastOp.getResult().getType().cast<MemRefType>();
+ auto memrefType = cast<MemRefType>(reinterpretCastOp.getResult().getType());
unsigned rank = memrefType.getRank();
SmallVector<OpFoldResult> results;
results.resize_for_overwrite(rank * 2 + 2);
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp
index 5141b5f33cfa..05ba6a3f3870 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp
@@ -120,7 +120,7 @@ template <typename TransferLikeOp>
static FailureOr<Value>
getTransferLikeOpSrcMemRef(TransferLikeOp transferLikeOp) {
Value src = transferLikeOp.getSource();
- if (src.getType().isa<MemRefType>())
+ if (isa<MemRefType>(src.getType()))
return src;
return failure();
}
@@ -240,7 +240,7 @@ struct LoadStoreLikeOpRewriter : public OpRewritePattern<LoadStoreLikeOp> {
return rewriter.notifyMatchFailure(loadStoreLikeOp,
"source is not a memref");
Value srcMemRef = *failureOrSrcMemRef;
- auto ldStTy = srcMemRef.getType().cast<MemRefType>();
+ auto ldStTy = cast<MemRefType>(srcMemRef.getType());
unsigned loadStoreRank = ldStTy.getRank();
// Don't waste compile time if there is nothing to rewrite.
if (loadStoreRank == 0)
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
index 72675b03abf6..2c30e98dd107 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
@@ -148,7 +148,7 @@ resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter,
if (collapseShapeOp.getReassociationIndices().empty()) {
auto zeroAffineMap = rewriter.getConstantAffineMap(0);
int64_t srcRank =
- collapseShapeOp.getViewSource().getType().cast<MemRefType>().getRank();
+ cast<MemRefType>(collapseShapeOp.getViewSource().getType()).getRank();
for (int64_t i = 0; i < srcRank; i++) {
OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
rewriter, loc, zeroAffineMap, dynamicIndices);
diff --git a/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp b/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp
index aa1d27dc863e..68b72eff8c97 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp
@@ -71,11 +71,9 @@ propagateSubViewOp(RewriterBase &rewriter,
UnrealizedConversionCastOp conversionOp, SubViewOp op) {
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(op);
- auto newResultType =
- SubViewOp::inferRankReducedResultType(
- op.getType().getShape(), op.getSourceType(), op.getMixedOffsets(),
- op.getMixedSizes(), op.getMixedStrides())
- .cast<MemRefType>();
+ auto newResultType = cast<MemRefType>(SubViewOp::inferRankReducedResultType(
+ op.getType().getShape(), op.getSourceType(), op.getMixedOffsets(),
+ op.getMixedSizes(), op.getMixedStrides()));
Value newSubview = rewriter.create<SubViewOp>(
op.getLoc(), newResultType, conversionOp.getOperand(0),
op.getMixedOffsets(), op.getMixedSizes(), op.getMixedStrides());
diff --git a/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp b/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
index ee1adcce80e5..eb1df2a87b99 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
@@ -61,11 +61,11 @@ static void replaceUsesAndPropagateType(RewriterBase &rewriter,
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(subviewUse);
Type newType = memref::SubViewOp::inferRankReducedResultType(
- subviewUse.getType().getShape(), val.getType().cast<MemRefType>(),
+ subviewUse.getType().getShape(), cast<MemRefType>(val.getType()),
subviewUse.getStaticOffsets(), subviewUse.getStaticSizes(),
subviewUse.getStaticStrides());
Value newSubview = rewriter.create<memref::SubViewOp>(
- subviewUse->getLoc(), newType.cast<MemRefType>(), val,
+ subviewUse->getLoc(), cast<MemRefType>(newType), val,
subviewUse.getMixedOffsets(), subviewUse.getMixedSizes(),
subviewUse.getMixedStrides());
@@ -209,9 +209,9 @@ mlir::memref::multiBuffer(RewriterBase &rewriter, memref::AllocOp allocOp,
for (int64_t i = 0, e = originalShape.size(); i != e; ++i)
sizes[1 + i] = rewriter.getIndexAttr(originalShape[i]);
// Strides is [1, 1 ... 1 ].
- auto dstMemref = memref::SubViewOp::inferRankReducedResultType(
- originalShape, mbMemRefType, offsets, sizes, strides)
- .cast<MemRefType>();
+ auto dstMemref =
+ cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
+ originalShape, mbMemRefType, offsets, sizes, strides));
Value subview = rewriter.create<memref::SubViewOp>(loc, dstMemref, mbAlloc,
offsets, sizes, strides);
LLVM_DEBUG(DBGS() << "--multi-buffered slice: " << subview << "\n");
diff --git a/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp
index c252433d16fa..aa21497fad8f 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp
@@ -180,7 +180,7 @@ bool NormalizeMemRefs::areMemRefsNormalizable(func::FuncOp funcOp) {
llvm::seq<unsigned>(0, callOp.getNumResults())) {
Value oldMemRef = callOp.getResult(resIndex);
if (auto oldMemRefType =
- oldMemRef.getType().dyn_cast<MemRefType>())
+ dyn_cast<MemRefType>(oldMemRef.getType()))
if (!oldMemRefType.getLayout().isIdentity() &&
!isMemRefNormalizable(oldMemRef.getUsers()))
return WalkResult::interrupt();
@@ -192,7 +192,7 @@ bool NormalizeMemRefs::areMemRefsNormalizable(func::FuncOp funcOp) {
for (unsigned argIndex : llvm::seq<unsigned>(0, funcOp.getNumArguments())) {
BlockArgument oldMemRef = funcOp.getArgument(argIndex);
- if (auto oldMemRefType = oldMemRef.getType().dyn_cast<MemRefType>())
+ if (auto oldMemRefType = dyn_cast<MemRefType>(oldMemRef.getType()))
if (!oldMemRefType.getLayout().isIdentity() &&
!isMemRefNormalizable(oldMemRef.getUsers()))
return false;
@@ -226,7 +226,7 @@ void NormalizeMemRefs::updateFunctionSignature(func::FuncOp funcOp,
funcOp.walk([&](func::ReturnOp returnOp) {
for (const auto &operandEn : llvm::enumerate(returnOp.getOperands())) {
Type opType = operandEn.value().getType();
- MemRefType memrefType = opType.dyn_cast<MemRefType>();
+ MemRefType memrefType = dyn_cast<MemRefType>(opType);
// If type is not memref or if the memref type is same as that in
// function's return signature then no update is required.
if (!memrefType || memrefType == resultTypes[operandEn.index()])
@@ -284,7 +284,7 @@ void NormalizeMemRefs::updateFunctionSignature(func::FuncOp funcOp,
if (oldResult.getType() == newResult.getType())
continue;
AffineMap layoutMap =
- oldResult.getType().cast<MemRefType>().getLayout().getAffineMap();
+ cast<MemRefType>(oldResult.getType()).getLayout().getAffineMap();
if (failed(replaceAllMemRefUsesWith(oldResult, /*newMemRef=*/newResult,
/*extraIndices=*/{},
/*indexRemap=*/layoutMap,
@@ -358,7 +358,7 @@ void NormalizeMemRefs::normalizeFuncOpMemRefs(func::FuncOp funcOp,
for (unsigned argIndex :
llvm::seq<unsigned>(0, functionType.getNumInputs())) {
Type argType = functionType.getInput(argIndex);
- MemRefType memrefType = argType.dyn_cast<MemRefType>();
+ MemRefType memrefType = dyn_cast<MemRefType>(argType);
// Check whether argument is of MemRef type. Any other argument type can
// simply be part of the final function signature.
if (!memrefType) {
@@ -422,11 +422,11 @@ void NormalizeMemRefs::normalizeFuncOpMemRefs(func::FuncOp funcOp,
// Replace all uses of the old memrefs.
Value oldMemRef = op->getResult(resIndex);
Value newMemRef = newOp->getResult(resIndex);
- MemRefType oldMemRefType = oldMemRef.getType().dyn_cast<MemRefType>();
+ MemRefType oldMemRefType = dyn_cast<MemRefType>(oldMemRef.getType());
// Check whether the operation result is MemRef type.
if (!oldMemRefType)
continue;
- MemRefType newMemRefType = newMemRef.getType().cast<MemRefType>();
+ MemRefType newMemRefType = cast<MemRefType>(newMemRef.getType());
if (oldMemRefType == newMemRefType)
continue;
// TODO: Assume single layout map. Multiple maps not supported.
@@ -466,7 +466,7 @@ void NormalizeMemRefs::normalizeFuncOpMemRefs(func::FuncOp funcOp,
for (unsigned resIndex :
llvm::seq<unsigned>(0, functionType.getNumResults())) {
Type resType = functionType.getResult(resIndex);
- MemRefType memrefType = resType.dyn_cast<MemRefType>();
+ MemRefType memrefType = dyn_cast<MemRefType>(resType);
// Check whether result is of MemRef type. Any other argument type can
// simply be part of the final function signature.
if (!memrefType) {
@@ -507,7 +507,7 @@ Operation *NormalizeMemRefs::createOpResultsNormalized(func::FuncOp funcOp,
bool resultTypeNormalized = false;
for (unsigned resIndex : llvm::seq<unsigned>(0, oldOp->getNumResults())) {
auto resultType = oldOp->getResult(resIndex).getType();
- MemRefType memrefType = resultType.dyn_cast<MemRefType>();
+ MemRefType memrefType = dyn_cast<MemRefType>(resultType);
// Check whether the operation result is MemRef type.
if (!memrefType) {
resultTypes.push_back(resultType);
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
index 8c544bbd9fb0..526c1c6e198f 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
@@ -40,7 +40,7 @@ struct DimOfShapedTypeOpInterface : public OpRewritePattern<OpTy> {
LogicalResult matchAndRewrite(OpTy dimOp,
PatternRewriter &rewriter) const override {
- OpResult dimValue = dimOp.getSource().template dyn_cast<OpResult>();
+ OpResult dimValue = dyn_cast<OpResult>(dimOp.getSource());
if (!dimValue)
return failure();
auto shapedTypeOp =
@@ -61,8 +61,8 @@ struct DimOfShapedTypeOpInterface : public OpRewritePattern<OpTy> {
return failure();
Value resultShape = reifiedResultShapes[dimValue.getResultNumber()];
- auto resultShapeType = resultShape.getType().dyn_cast<RankedTensorType>();
- if (!resultShapeType || !resultShapeType.getElementType().isa<IndexType>())
+ auto resultShapeType = dyn_cast<RankedTensorType>(resultShape.getType());
+ if (!resultShapeType || !isa<IndexType>(resultShapeType.getElementType()))
return failure();
Location loc = dimOp->getLoc();
@@ -82,7 +82,7 @@ struct DimOfReifyRankedShapedTypeOpInterface : public OpRewritePattern<OpTy> {
LogicalResult matchAndRewrite(OpTy dimOp,
PatternRewriter &rewriter) const override {
- OpResult dimValue = dimOp.getSource().template dyn_cast<OpResult>();
+ OpResult dimValue = dyn_cast<OpResult>(dimOp.getSource());
if (!dimValue)
return failure();
std::optional<int64_t> dimIndex = dimOp.getConstantIndex();
diff --git a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
index 9ffb315587e3..05a069d98ef3 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
@@ -38,14 +38,14 @@ struct CastOpInterface
void generateRuntimeVerification(Operation *op, OpBuilder &builder,
Location loc) const {
auto castOp = cast<CastOp>(op);
- auto srcType = castOp.getSource().getType().cast<BaseMemRefType>();
+ auto srcType = cast<BaseMemRefType>(castOp.getSource().getType());
// Nothing to check if the result is an unranked memref.
- auto resultType = castOp.getType().dyn_cast<MemRefType>();
+ auto resultType = dyn_cast<MemRefType>(castOp.getType());
if (!resultType)
return;
- if (srcType.isa<UnrankedMemRefType>()) {
+ if (isa<UnrankedMemRefType>(srcType)) {
// Check rank.
Value srcRank = builder.create<RankOp>(loc, castOp.getSource());
Value resultRank =
@@ -75,7 +75,7 @@ struct CastOpInterface
// Check dimension sizes.
for (const auto &it : llvm::enumerate(resultType.getShape())) {
// Static dim size -> static/dynamic dim size does not need verification.
- if (auto rankedSrcType = srcType.dyn_cast<MemRefType>())
+ if (auto rankedSrcType = dyn_cast<MemRefType>(srcType))
if (!rankedSrcType.isDynamicDim(it.index()))
continue;