Diffstat (limited to 'mlir/lib/Dialect/MemRef')
12 files changed, 55 insertions, 61 deletions
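The change is mechanical throughout: member-function casts on Type/Attribute/Value (x.isa<T>(), x.cast<T>(), x.dyn_cast<T>()) are rewritten to the free functions isa<T>(x), cast<T>(x), dyn_cast<T>(x). A minimal sketch of the idiom follows, assuming an MLIR build environment; the helper name isStaticMemRef is hypothetical and only illustrates the pattern used in the hunks below.

#include "mlir/IR/BuiltinTypes.h" // MemRefType
#include "mlir/IR/Value.h"        // Value

using namespace mlir;

// Old style: v.getType().isa<MemRefType>() / v.getType().cast<MemRefType>().
// New style, as used in every hunk below: free-function casts.
static bool isStaticMemRef(Value v) {
  // dyn_cast returns a null MemRefType if the value is not a memref.
  auto memrefTy = dyn_cast<MemRefType>(v.getType());
  return memrefTy && memrefTy.hasStaticShape();
}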
diff --git a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
index 7f702e197854..ae2472db4f86 100644
--- a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
+++ b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
@@ -69,7 +69,7 @@ DiagnosedSilenceableFailure transform::MemRefMultiBufferOp::apply(
     results.push_back(*newBuffer);
   }
 
-  transformResults.set(getResult().cast<OpResult>(), results);
+  transformResults.set(cast<OpResult>(getResult()), results);
   return DiagnosedSilenceableFailure::success();
 }
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp b/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp
index 369f22521895..9b1d85b29027 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp
@@ -57,7 +57,7 @@ struct ComposeSubViewOpPattern : public OpRewritePattern<memref::SubViewOp> {
     // always 1.
     if (llvm::all_of(strides, [](OpFoldResult &valueOrAttr) {
           Attribute attr = valueOrAttr.dyn_cast<Attribute>();
-          return attr && attr.cast<IntegerAttr>().getInt() == 1;
+          return attr && cast<IntegerAttr>(attr).getInt() == 1;
         })) {
       strides = SmallVector<OpFoldResult>(sourceOp.getMixedStrides().size(),
                                           rewriter.getI64IntegerAttr(1));
@@ -93,8 +93,8 @@ struct ComposeSubViewOpPattern : public OpRewritePattern<memref::SubViewOp> {
         // If both offsets are static we can simply calculate the combined
         // offset statically.
         offsets.push_back(rewriter.getI64IntegerAttr(
-            opOffsetAttr.cast<IntegerAttr>().getInt() +
-            sourceOffsetAttr.cast<IntegerAttr>().getInt()));
+            cast<IntegerAttr>(opOffsetAttr).getInt() +
+            cast<IntegerAttr>(sourceOffsetAttr).getInt()));
       } else {
         // When either offset is dynamic, we must emit an additional affine
         // transformation to add the two offsets together dynamically.
@@ -102,7 +102,7 @@ struct ComposeSubViewOpPattern : public OpRewritePattern<memref::SubViewOp> {
         SmallVector<Value> affineApplyOperands;
         for (auto valueOrAttr : {opOffset, sourceOffset}) {
           if (auto attr = valueOrAttr.dyn_cast<Attribute>()) {
-            expr = expr + attr.cast<IntegerAttr>().getInt();
+            expr = expr + cast<IntegerAttr>(attr).getInt();
           } else {
             expr = expr +
                    rewriter.getAffineSymbolExpr(affineApplyOperands.size());
diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp
index 6202b5730c21..57f0141c95dc 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp
@@ -149,7 +149,7 @@ void memref::populateMemRefWideIntEmulationConversions(
     arith::WideIntEmulationConverter &typeConverter) {
   typeConverter.addConversion(
       [&typeConverter](MemRefType ty) -> std::optional<Type> {
-        auto intTy = ty.getElementType().dyn_cast<IntegerType>();
+        auto intTy = dyn_cast<IntegerType>(ty.getElementType());
         if (!intTy)
           return ty;
 
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp
index 38fb11348f28..8a276ebbff6a 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp
@@ -89,11 +89,11 @@ public:
 
   LogicalResult matchAndRewrite(memref::ReshapeOp op,
                                 PatternRewriter &rewriter) const final {
-    auto shapeType = op.getShape().getType().cast<MemRefType>();
+    auto shapeType = cast<MemRefType>(op.getShape().getType());
     if (!shapeType.hasStaticShape())
       return failure();
 
-    int64_t rank = shapeType.cast<MemRefType>().getDimSize(0);
+    int64_t rank = cast<MemRefType>(shapeType).getDimSize(0);
     SmallVector<OpFoldResult, 4> sizes, strides;
     sizes.resize(rank);
     strides.resize(rank);
@@ -106,7 +106,7 @@ public:
       if (op.getType().isDynamicDim(i)) {
         Value index = rewriter.create<arith::ConstantIndexOp>(loc, i);
         size = rewriter.create<memref::LoadOp>(loc, op.getShape(), index);
-        if (!size.getType().isa<IndexType>())
+        if (!isa<IndexType>(size.getType()))
           size = rewriter.create<arith::IndexCastOp>(
               loc, rewriter.getIndexType(), size);
         sizes[i] = size;
@@ -141,7 +141,7 @@ struct ExpandOpsPass : public memref::impl::ExpandOpsBase<ExpandOpsPass> {
              op.getKind() != arith::AtomicRMWKind::minf;
     });
     target.addDynamicallyLegalOp<memref::ReshapeOp>([](memref::ReshapeOp op) {
-      return !op.getShape().getType().cast<MemRefType>().hasStaticShape();
+      return !cast<MemRefType>(op.getShape().getType()).hasStaticShape();
     });
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns))))
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
index ea372bffbc0b..ff2c4107ee46 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
@@ -62,7 +62,7 @@ resolveSubviewStridedMetadata(RewriterBase &rewriter,
   // Build a plain extract_strided_metadata(memref) from subview(memref).
   Location origLoc = subview.getLoc();
   Value source = subview.getSource();
-  auto sourceType = source.getType().cast<MemRefType>();
+  auto sourceType = cast<MemRefType>(source.getType());
   unsigned sourceRank = sourceType.getRank();
 
   auto newExtractStridedMetadata =
@@ -115,7 +115,7 @@ resolveSubviewStridedMetadata(RewriterBase &rewriter,
   // The final result is <baseBuffer, offset, sizes, strides>.
   // Thus we need 1 + 1 + subview.getRank() + subview.getRank(), to hold all
   // the values.
-  auto subType = subview.getType().cast<MemRefType>();
+  auto subType = cast<MemRefType>(subview.getType());
   unsigned subRank = subType.getRank();
 
   // The sizes of the final type are defined directly by the input sizes of
@@ -338,7 +338,7 @@ SmallVector<OpFoldResult> getExpandedStrides(memref::ExpandShapeOp expandShape,
 
   // Collect the statically known information about the original stride.
   Value source = expandShape.getSrc();
-  auto sourceType = source.getType().cast<MemRefType>();
+  auto sourceType = cast<MemRefType>(source.getType());
   auto [strides, offset] = getStridesAndOffset(sourceType);
 
   OpFoldResult origStride = ShapedType::isDynamic(strides[groupId])
@@ -358,10 +358,9 @@ SmallVector<OpFoldResult> getExpandedStrides(memref::ExpandShapeOp expandShape,
     AffineExpr s0 = builder.getAffineSymbolExpr(0);
     AffineExpr s1 = builder.getAffineSymbolExpr(1);
     for (; doneStrideIdx < *dynSizeIdx; ++doneStrideIdx) {
-      int64_t baseExpandedStride = expandedStrides[doneStrideIdx]
-                                       .get<Attribute>()
-                                       .cast<IntegerAttr>()
-                                       .getInt();
+      int64_t baseExpandedStride =
+          cast<IntegerAttr>(expandedStrides[doneStrideIdx].get<Attribute>())
+              .getInt();
       expandedStrides[doneStrideIdx] = makeComposedFoldedAffineApply(
           builder, expandShape.getLoc(),
           (s0 * baseExpandedStride).floorDiv(productOfAllStaticSizes) * s1,
@@ -372,10 +371,9 @@ SmallVector<OpFoldResult> getExpandedStrides(memref::ExpandShapeOp expandShape,
   // Now apply the origStride to the remaining dimensions.
   AffineExpr s0 = builder.getAffineSymbolExpr(0);
   for (; doneStrideIdx < groupSize; ++doneStrideIdx) {
-    int64_t baseExpandedStride = expandedStrides[doneStrideIdx]
-                                     .get<Attribute>()
-                                     .cast<IntegerAttr>()
-                                     .getInt();
+    int64_t baseExpandedStride =
+        cast<IntegerAttr>(expandedStrides[doneStrideIdx].get<Attribute>())
+            .getInt();
     expandedStrides[doneStrideIdx] = makeComposedFoldedAffineApply(
         builder, expandShape.getLoc(), s0 * baseExpandedStride, {origStride});
   }
@@ -445,7 +443,7 @@ getCollapsedSize(memref::CollapseShapeOp collapseShape, OpBuilder &builder,
   // Build the affine expr of the product of the original sizes involved in that
   // group.
   Value source = collapseShape.getSrc();
-  auto sourceType = source.getType().cast<MemRefType>();
+  auto sourceType = cast<MemRefType>(source.getType());
 
   SmallVector<int64_t, 2> reassocGroup =
       collapseShape.getReassociationIndices()[groupId];
@@ -479,7 +477,7 @@ getCollapsedStride(memref::CollapseShapeOp collapseShape, OpBuilder &builder,
          "Reassociation group should have at least one dimension");
 
   Value source = collapseShape.getSrc();
-  auto sourceType = source.getType().cast<MemRefType>();
+  auto sourceType = cast<MemRefType>(source.getType());
 
   auto [strides, offset] = getStridesAndOffset(sourceType);
 
@@ -562,7 +560,7 @@ public:
     // extract_strided_metadata(reassociative_reshape_like(memref)).
     Location origLoc = reshape.getLoc();
     Value source = reshape.getSrc();
-    auto sourceType = source.getType().cast<MemRefType>();
+    auto sourceType = cast<MemRefType>(source.getType());
     unsigned sourceRank = sourceType.getRank();
 
     auto newExtractStridedMetadata =
@@ -650,8 +648,7 @@ public:
     if (!allocLikeOp)
       return failure();
 
-    auto memRefType =
-        allocLikeOp.getResult().getType().template cast<MemRefType>();
+    auto memRefType = cast<MemRefType>(allocLikeOp.getResult().getType());
     if (!memRefType.getLayout().isIdentity())
       return rewriter.notifyMatchFailure(
           allocLikeOp, "alloc-like operations should have been normalized");
@@ -688,7 +685,7 @@ public:
     SmallVector<Value> results;
     results.reserve(rank * 2 + 2);
 
-    auto baseBufferType = op.getBaseBuffer().getType().cast<MemRefType>();
+    auto baseBufferType = cast<MemRefType>(op.getBaseBuffer().getType());
     int64_t offset = 0;
     if (allocLikeOp.getType() == baseBufferType)
       results.push_back(allocLikeOp);
@@ -737,7 +734,7 @@ public:
     if (!getGlobalOp)
       return failure();
 
-    auto memRefType = getGlobalOp.getResult().getType().cast<MemRefType>();
+    auto memRefType = cast<MemRefType>(getGlobalOp.getResult().getType());
     if (!memRefType.getLayout().isIdentity()) {
       return rewriter.notifyMatchFailure(
           getGlobalOp,
@@ -759,7 +756,7 @@ public:
     SmallVector<Value> results;
     results.reserve(rank * 2 + 2);
 
-    auto baseBufferType = op.getBaseBuffer().getType().cast<MemRefType>();
+    auto baseBufferType = cast<MemRefType>(op.getBaseBuffer().getType());
     int64_t offset = 0;
     if (getGlobalOp.getType() == baseBufferType)
       results.push_back(getGlobalOp);
@@ -838,8 +835,7 @@ class ExtractStridedMetadataOpReinterpretCastFolder
       return rewriter.notifyMatchFailure(
           reinterpretCastOp, "reinterpret_cast source's type is incompatible");
 
-    auto memrefType =
-        reinterpretCastOp.getResult().getType().cast<MemRefType>();
+    auto memrefType = cast<MemRefType>(reinterpretCastOp.getResult().getType());
     unsigned rank = memrefType.getRank();
     SmallVector<OpFoldResult> results;
     results.resize_for_overwrite(rank * 2 + 2);
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp
index 5141b5f33cfa..05ba6a3f3870 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp
@@ -120,7 +120,7 @@ template <typename TransferLikeOp>
 static FailureOr<Value>
 getTransferLikeOpSrcMemRef(TransferLikeOp transferLikeOp) {
   Value src = transferLikeOp.getSource();
-  if (src.getType().isa<MemRefType>())
+  if (isa<MemRefType>(src.getType()))
     return src;
   return failure();
 }
@@ -240,7 +240,7 @@ struct LoadStoreLikeOpRewriter : public OpRewritePattern<LoadStoreLikeOp> {
       return rewriter.notifyMatchFailure(loadStoreLikeOp,
                                          "source is not a memref");
     Value srcMemRef = *failureOrSrcMemRef;
-    auto ldStTy = srcMemRef.getType().cast<MemRefType>();
+    auto ldStTy = cast<MemRefType>(srcMemRef.getType());
     unsigned loadStoreRank = ldStTy.getRank();
     // Don't waste compile time if there is nothing to rewrite.
     if (loadStoreRank == 0)
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
index 72675b03abf6..2c30e98dd107 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
@@ -148,7 +148,7 @@ resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter,
   if (collapseShapeOp.getReassociationIndices().empty()) {
     auto zeroAffineMap = rewriter.getConstantAffineMap(0);
     int64_t srcRank =
-        collapseShapeOp.getViewSource().getType().cast<MemRefType>().getRank();
+        cast<MemRefType>(collapseShapeOp.getViewSource().getType()).getRank();
     for (int64_t i = 0; i < srcRank; i++) {
       OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
           rewriter, loc, zeroAffineMap, dynamicIndices);
diff --git a/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp b/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp
index aa1d27dc863e..68b72eff8c97 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp
@@ -71,11 +71,9 @@ propagateSubViewOp(RewriterBase &rewriter,
                    UnrealizedConversionCastOp conversionOp, SubViewOp op) {
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPoint(op);
-  auto newResultType =
-      SubViewOp::inferRankReducedResultType(
-          op.getType().getShape(), op.getSourceType(), op.getMixedOffsets(),
-          op.getMixedSizes(), op.getMixedStrides())
-          .cast<MemRefType>();
+  auto newResultType = cast<MemRefType>(SubViewOp::inferRankReducedResultType(
+      op.getType().getShape(), op.getSourceType(), op.getMixedOffsets(),
+      op.getMixedSizes(), op.getMixedStrides()));
   Value newSubview = rewriter.create<SubViewOp>(
       op.getLoc(), newResultType, conversionOp.getOperand(0),
       op.getMixedOffsets(), op.getMixedSizes(), op.getMixedStrides());
diff --git a/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp b/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
index ee1adcce80e5..eb1df2a87b99 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
@@ -61,11 +61,11 @@ static void replaceUsesAndPropagateType(RewriterBase &rewriter,
     OpBuilder::InsertionGuard g(rewriter);
     rewriter.setInsertionPoint(subviewUse);
     Type newType = memref::SubViewOp::inferRankReducedResultType(
-        subviewUse.getType().getShape(), val.getType().cast<MemRefType>(),
+        subviewUse.getType().getShape(), cast<MemRefType>(val.getType()),
         subviewUse.getStaticOffsets(), subviewUse.getStaticSizes(),
         subviewUse.getStaticStrides());
     Value newSubview = rewriter.create<memref::SubViewOp>(
-        subviewUse->getLoc(), newType.cast<MemRefType>(), val,
+        subviewUse->getLoc(), cast<MemRefType>(newType), val,
         subviewUse.getMixedOffsets(), subviewUse.getMixedSizes(),
         subviewUse.getMixedStrides());
 
@@ -209,9 +209,9 @@ mlir::memref::multiBuffer(RewriterBase &rewriter, memref::AllocOp allocOp,
   for (int64_t i = 0, e = originalShape.size(); i != e; ++i)
     sizes[1 + i] = rewriter.getIndexAttr(originalShape[i]);
   // Strides is [1, 1 ... 1 ].
-  auto dstMemref = memref::SubViewOp::inferRankReducedResultType(
-                       originalShape, mbMemRefType, offsets, sizes, strides)
-                       .cast<MemRefType>();
+  auto dstMemref =
+      cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
+          originalShape, mbMemRefType, offsets, sizes, strides));
   Value subview = rewriter.create<memref::SubViewOp>(loc, dstMemref, mbAlloc,
                                                      offsets, sizes, strides);
   LLVM_DEBUG(DBGS() << "--multi-buffered slice: " << subview << "\n");
diff --git a/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp
index c252433d16fa..aa21497fad8f 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp
@@ -180,7 +180,7 @@ bool NormalizeMemRefs::areMemRefsNormalizable(func::FuncOp funcOp) {
              llvm::seq<unsigned>(0, callOp.getNumResults())) {
           Value oldMemRef = callOp.getResult(resIndex);
           if (auto oldMemRefType =
-                  oldMemRef.getType().dyn_cast<MemRefType>())
+                  dyn_cast<MemRefType>(oldMemRef.getType()))
             if (!oldMemRefType.getLayout().isIdentity() &&
                 !isMemRefNormalizable(oldMemRef.getUsers()))
               return WalkResult::interrupt();
@@ -192,7 +192,7 @@ bool NormalizeMemRefs::areMemRefsNormalizable(func::FuncOp funcOp) {
   for (unsigned argIndex :
        llvm::seq<unsigned>(0, funcOp.getNumArguments())) {
     BlockArgument oldMemRef = funcOp.getArgument(argIndex);
-    if (auto oldMemRefType = oldMemRef.getType().dyn_cast<MemRefType>())
+    if (auto oldMemRefType = dyn_cast<MemRefType>(oldMemRef.getType()))
       if (!oldMemRefType.getLayout().isIdentity() &&
           !isMemRefNormalizable(oldMemRef.getUsers()))
         return false;
@@ -226,7 +226,7 @@ void NormalizeMemRefs::updateFunctionSignature(func::FuncOp funcOp,
     funcOp.walk([&](func::ReturnOp returnOp) {
       for (const auto &operandEn : llvm::enumerate(returnOp.getOperands())) {
         Type opType = operandEn.value().getType();
-        MemRefType memrefType = opType.dyn_cast<MemRefType>();
+        MemRefType memrefType = dyn_cast<MemRefType>(opType);
         // If type is not memref or if the memref type is same as that in
         // function's return signature then no update is required.
         if (!memrefType || memrefType == resultTypes[operandEn.index()])
@@ -284,7 +284,7 @@ void NormalizeMemRefs::updateFunctionSignature(func::FuncOp funcOp,
       if (oldResult.getType() == newResult.getType())
         continue;
       AffineMap layoutMap =
-          oldResult.getType().cast<MemRefType>().getLayout().getAffineMap();
+          cast<MemRefType>(oldResult.getType()).getLayout().getAffineMap();
       if (failed(replaceAllMemRefUsesWith(oldResult, /*newMemRef=*/newResult,
                                           /*extraIndices=*/{},
                                           /*indexRemap=*/layoutMap,
@@ -358,7 +358,7 @@ void NormalizeMemRefs::normalizeFuncOpMemRefs(func::FuncOp funcOp,
   for (unsigned argIndex :
        llvm::seq<unsigned>(0, functionType.getNumInputs())) {
     Type argType = functionType.getInput(argIndex);
-    MemRefType memrefType = argType.dyn_cast<MemRefType>();
+    MemRefType memrefType = dyn_cast<MemRefType>(argType);
     // Check whether argument is of MemRef type. Any other argument type can
     // simply be part of the final function signature.
     if (!memrefType) {
@@ -422,11 +422,11 @@ void NormalizeMemRefs::normalizeFuncOpMemRefs(func::FuncOp funcOp,
       // Replace all uses of the old memrefs.
       Value oldMemRef = op->getResult(resIndex);
       Value newMemRef = newOp->getResult(resIndex);
-      MemRefType oldMemRefType = oldMemRef.getType().dyn_cast<MemRefType>();
+      MemRefType oldMemRefType = dyn_cast<MemRefType>(oldMemRef.getType());
       // Check whether the operation result is MemRef type.
       if (!oldMemRefType)
         continue;
-      MemRefType newMemRefType = newMemRef.getType().cast<MemRefType>();
+      MemRefType newMemRefType = cast<MemRefType>(newMemRef.getType());
       if (oldMemRefType == newMemRefType)
         continue;
       // TODO: Assume single layout map. Multiple maps not supported.
@@ -466,7 +466,7 @@ void NormalizeMemRefs::normalizeFuncOpMemRefs(func::FuncOp funcOp,
   for (unsigned resIndex :
        llvm::seq<unsigned>(0, functionType.getNumResults())) {
     Type resType = functionType.getResult(resIndex);
-    MemRefType memrefType = resType.dyn_cast<MemRefType>();
+    MemRefType memrefType = dyn_cast<MemRefType>(resType);
     // Check whether result is of MemRef type. Any other argument type can
     // simply be part of the final function signature.
     if (!memrefType) {
@@ -507,7 +507,7 @@ Operation *NormalizeMemRefs::createOpResultsNormalized(func::FuncOp funcOp,
   bool resultTypeNormalized = false;
   for (unsigned resIndex : llvm::seq<unsigned>(0, oldOp->getNumResults())) {
     auto resultType = oldOp->getResult(resIndex).getType();
-    MemRefType memrefType = resultType.dyn_cast<MemRefType>();
+    MemRefType memrefType = dyn_cast<MemRefType>(resultType);
     // Check whether the operation result is MemRef type.
     if (!memrefType) {
       resultTypes.push_back(resultType);
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
index 8c544bbd9fb0..526c1c6e198f 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
@@ -40,7 +40,7 @@ struct DimOfShapedTypeOpInterface : public OpRewritePattern<OpTy> {
 
   LogicalResult matchAndRewrite(OpTy dimOp,
                                 PatternRewriter &rewriter) const override {
-    OpResult dimValue = dimOp.getSource().template dyn_cast<OpResult>();
+    OpResult dimValue = dyn_cast<OpResult>(dimOp.getSource());
     if (!dimValue)
       return failure();
     auto shapedTypeOp =
@@ -61,8 +61,8 @@ struct DimOfShapedTypeOpInterface : public OpRewritePattern<OpTy> {
       return failure();
 
     Value resultShape = reifiedResultShapes[dimValue.getResultNumber()];
-    auto resultShapeType = resultShape.getType().dyn_cast<RankedTensorType>();
-    if (!resultShapeType || !resultShapeType.getElementType().isa<IndexType>())
+    auto resultShapeType = dyn_cast<RankedTensorType>(resultShape.getType());
+    if (!resultShapeType || !isa<IndexType>(resultShapeType.getElementType()))
       return failure();
 
     Location loc = dimOp->getLoc();
@@ -82,7 +82,7 @@ struct DimOfReifyRankedShapedTypeOpInterface : public OpRewritePattern<OpTy> {
 
   LogicalResult matchAndRewrite(OpTy dimOp,
                                 PatternRewriter &rewriter) const override {
-    OpResult dimValue = dimOp.getSource().template dyn_cast<OpResult>();
+    OpResult dimValue = dyn_cast<OpResult>(dimOp.getSource());
     if (!dimValue)
       return failure();
     std::optional<int64_t> dimIndex = dimOp.getConstantIndex();
diff --git a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
index 9ffb315587e3..05a069d98ef3 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
@@ -38,14 +38,14 @@ struct CastOpInterface
   void generateRuntimeVerification(Operation *op, OpBuilder &builder,
                                    Location loc) const {
     auto castOp = cast<CastOp>(op);
-    auto srcType = castOp.getSource().getType().cast<BaseMemRefType>();
+    auto srcType = cast<BaseMemRefType>(castOp.getSource().getType());
 
     // Nothing to check if the result is an unranked memref.
-    auto resultType = castOp.getType().dyn_cast<MemRefType>();
+    auto resultType = dyn_cast<MemRefType>(castOp.getType());
     if (!resultType)
       return;
 
-    if (srcType.isa<UnrankedMemRefType>()) {
+    if (isa<UnrankedMemRefType>(srcType)) {
       // Check rank.
       Value srcRank = builder.create<RankOp>(loc, castOp.getSource());
       Value resultRank =
@@ -75,7 +75,7 @@ struct CastOpInterface
     // Check dimension sizes.
     for (const auto &it : llvm::enumerate(resultType.getShape())) {
       // Static dim size -> static/dynamic dim size does not need verification.
-      if (auto rankedSrcType = srcType.dyn_cast<MemRefType>())
+      if (auto rankedSrcType = dyn_cast<MemRefType>(srcType))
         if (!rankedSrcType.isDynamicDim(it.index()))
           continue;
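For reference, the ranked/unranked branching in the CastOpInterface hunk above reduces to the following idiom once the free-function casts are used; the helper getRankIfRanked is hypothetical and serves only to illustrate how isa<UnrankedMemRefType> and cast<MemRefType> combine after this change.

#include "mlir/IR/BuiltinTypes.h" // BaseMemRefType, MemRefType, UnrankedMemRefType

using namespace mlir;

// Returns the static rank, or -1 when the memref is unranked and its rank is
// only known at runtime (the case the runtime verification above guards with
// an explicit rank check).
static int64_t getRankIfRanked(BaseMemRefType type) {
  if (isa<UnrankedMemRefType>(type))
    return -1;
  return cast<MemRefType>(type).getRank();
}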