author     Tres Popp <tpopp@google.com>    2023-05-08 16:33:54 +0200
committer  Tres Popp <tpopp@google.com>    2023-05-12 11:21:25 +0200
commit     5550c821897ab77e664977121a0e90ad5be1ff59 (patch)
tree       1947e879997b2fccdb629789362c42310d1f1f84 /mlir/lib/Dialect
parent     5c8ce6d5761ed6a9a39ef5a77aa45d8b6095e0f5 (diff)
[mlir] Move casting calls from methods to function calls
The MLIR classes Type/Attribute/Operation/Op/Value support cast/dyn_cast/isa/dyn_cast_or_null functionality through llvm's doCast functionality in addition to defining methods with the same names. This change begins migrating uses of the methods to the corresponding free function calls, which has been decided to be the more consistent style. Note that there still exist classes that only define the methods directly, such as AffineExpr, and this change does not currently include work to support a functional cast/isa call for them.

Caveats include:
- This clang-tidy script probably has more problems.
- This only touches C++ code, so nothing that is being generated.

Context:
- https://mlir.llvm.org/deprecation/ at "Use the free function variants for dyn_cast/cast/isa/…"
- Original discussion at https://discourse.llvm.org/t/preferred-casting-style-going-forward/68443

Implementation: this first patch was created with the following steps. The intention is to do only automated changes at first, so less time is wasted if it's reverted, and so the first mass change is a clearer example to other teams that will need to follow similar steps. Steps are described one per line, as comments are removed by git:

0. Retrieve the change from https://github.com/llvm/llvm-project/compare/main...tpopp:llvm-project:tidy-cast-check to build clang-tidy with an additional check.
1. Build clang-tidy.
2. Run clang-tidy over the entire codebase with all checks disabled and only the relevant one enabled; run on all header files as well.
3. Delete .inc files that were also modified, so the next build regenerates them in a pure state.
4. Discard some of the generated changes, for the following reasons:
   - Some files had a variable also named cast.
   - Some files had not included a header file that defines the cast functions.
   - Some files are the definitions of the classes that have the casting methods, so the code still refers to the method instead of the function without adding a prefix or removing the method declaration at the same time.

```
ninja -C $BUILD_DIR clang-tidy

run-clang-tidy -clang-tidy-binary=$BUILD_DIR/bin/clang-tidy \
    -checks='-*,misc-cast-functions' -header-filter=mlir/ mlir/* -fix

rm -rf $BUILD_DIR/tools/mlir/**/*.inc

git restore mlir/lib/IR mlir/lib/Dialect/DLTI/DLTI.cpp \
    mlir/lib/Dialect/Complex/IR/ComplexDialect.cpp \
    mlir/lib/**/IR/ \
    mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp \
    mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp \
    mlir/test/lib/Dialect/Test/TestTypes.cpp \
    mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp \
    mlir/test/lib/Dialect/Test/TestAttributes.cpp \
    mlir/unittests/TableGen/EnumsGenTest.cpp \
    mlir/test/python/lib/PythonTestCAPI.cpp \
    mlir/include/mlir/IR/
```

Differential Revision: https://reviews.llvm.org/D150123
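For reference, a minimal before/after sketch of the cast style being migrated. The helper functions below are hypothetical illustrations, not code from this patch, but the cast patterns are the ones rewritten throughout it:

```
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"

using namespace mlir;

// Old style: casts invoked as methods on Type.
unsigned getMemRefRankOld(Value memref) {
  return memref.getType().cast<MemRefType>().getRank();
}

// New style: free functions resolved through llvm's doCast machinery.
unsigned getMemRefRankNew(Value memref) {
  return cast<MemRefType>(memref.getType()).getRank();
}

// In dependent contexts the free-function form also drops the awkward
// `.template isa<T>()` disambiguation that the method form requires:
//   memRefType.getElementType().template isa<VectorType>()
// becomes
//   isa<VectorType>(memRefType.getElementType())
bool hasVectorElementType(Value memref) {
  auto memRefType = cast<MemRefType>(memref.getType());
  return isa<VectorType>(memRefType.getElementType());
}
```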
Diffstat (limited to 'mlir/lib/Dialect')
-rw-r--r--  mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp | 4
-rw-r--r--  mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp | 2
-rw-r--r--  mlir/lib/Dialect/Affine/Analysis/LoopAnalysis.cpp | 4
-rw-r--r--  mlir/lib/Dialect/Affine/Analysis/Utils.cpp | 24
-rw-r--r--  mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp | 6
-rw-r--r--  mlir/lib/Dialect/Affine/Transforms/PipelineDataTransfer.cpp | 2
-rw-r--r--  mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp | 4
-rw-r--r--  mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp | 8
-rw-r--r--  mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp | 8
-rw-r--r--  mlir/lib/Dialect/Affine/Utils/Utils.cpp | 20
-rw-r--r--  mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp | 12
-rw-r--r--  mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp | 39
-rw-r--r--  mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp | 6
-rw-r--r--  mlir/lib/Dialect/Arith/Utils/Utils.cpp | 28
-rw-r--r--  mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp | 6
-rw-r--r--  mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp | 6
-rw-r--r--  mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp | 2
-rw-r--r--  mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp | 6
-rw-r--r--  mlir/lib/Dialect/Bufferization/Transforms/BufferOptimizations.cpp | 4
-rw-r--r--  mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp | 10
-rw-r--r--  mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp | 8
-rw-r--r--  mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp | 14
-rw-r--r--  mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp | 6
-rw-r--r--  mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp | 18
-rw-r--r--  mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp | 42
-rw-r--r--  mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp | 14
-rw-r--r--  mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp | 6
-rw-r--r--  mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp | 10
-rw-r--r--  mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp | 2
-rw-r--r--  mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp | 8
-rw-r--r--  mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp | 4
-rw-r--r--  mlir/lib/Dialect/GPU/Transforms/MemoryPromotion.cpp | 8
-rw-r--r--  mlir/lib/Dialect/IRDL/IRDLVerifiers.cpp | 6
-rw-r--r--  mlir/lib/Dialect/LLVMIR/Transforms/DIScopeForLLVMFuncOp.cpp | 6
-rw-r--r--  mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp | 4
-rw-r--r--  mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp | 105
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp | 14
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp | 26
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp | 26
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp | 2
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp | 10
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp | 6
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp | 17
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp | 35
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp | 8
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp | 2
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/FusePadOpWithLinalgProducer.cpp | 6
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp | 8
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp | 8
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp | 9
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp | 2
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/Loops.cpp | 4
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp | 6
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp | 4
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/Split.cpp | 2
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp | 6
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/SubsetHoisting.cpp | 8
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp | 10
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp | 3
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp | 38
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp | 57
-rw-r--r--  mlir/lib/Dialect/Linalg/Utils/IndexingUtils.cpp | 10
-rw-r--r--  mlir/lib/Dialect/Linalg/Utils/Utils.cpp | 18
-rw-r--r--  mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp | 4
-rw-r--r--  mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp | 16
-rw-r--r--  mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp | 2
-rw-r--r--  mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp | 8
-rw-r--r--  mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp | 2
-rw-r--r--  mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp | 8
-rw-r--r--  mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp | 38
-rw-r--r--  mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp | 4
-rw-r--r--  mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp | 2
-rw-r--r--  mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp | 8
-rw-r--r--  mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp | 10
-rw-r--r--  mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp | 18
-rw-r--r--  mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp | 8
-rw-r--r--  mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp | 8
-rw-r--r--  mlir/lib/Dialect/NVGPU/Transforms/MmaSyncTF32Transform.cpp | 2
-rw-r--r--  mlir/lib/Dialect/NVGPU/Transforms/OptimizeSharedMemory.cpp | 2
-rw-r--r--  mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp | 2
-rw-r--r--  mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp | 14
-rw-r--r--  mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp | 6
-rw-r--r--  mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp | 80
-rw-r--r--  mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp | 6
-rw-r--r--  mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp | 5
-rw-r--r--  mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp | 4
-rw-r--r--  mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp | 4
-rw-r--r--  mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp | 14
-rw-r--r--  mlir/lib/Dialect/SPIRV/Transforms/RewriteInsertsPass.cpp | 12
-rw-r--r--  mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp | 59
-rw-r--r--  mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp | 4
-rw-r--r--  mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp | 36
-rw-r--r--  mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp | 4
-rw-r--r--  mlir/lib/Dialect/SPIRV/Utils/LayoutUtils.cpp | 18
-rw-r--r--  mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp | 4
-rw-r--r--  mlir/lib/Dialect/Shape/Transforms/OutlineShapeComputation.cpp | 2
-rw-r--r--  mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp | 2
-rw-r--r--  mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp | 36
-rw-r--r--  mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h | 6
-rw-r--r--  mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp | 4
-rw-r--r--  mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h | 4
-rw-r--r--  mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp | 8
-rw-r--r--  mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp | 6
-rw-r--r--  mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp | 12
-rw-r--r--  mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp | 10
-rw-r--r--  mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp | 8
-rw-r--r--  mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h | 4
-rw-r--r--  mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp | 12
-rw-r--r--  mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp | 22
-rw-r--r--  mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp | 42
-rw-r--r--  mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshapeUtils.cpp | 4
-rw-r--r--  mlir/lib/Dialect/Tensor/Utils/Utils.cpp | 6
-rw-r--r--  mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp | 10
-rw-r--r--  mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp | 16
-rw-r--r--  mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp | 24
-rw-r--r--  mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantTranspose.cpp | 6
-rw-r--r--  mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp | 11
-rw-r--r--  mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp | 24
-rw-r--r--  mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp | 2
-rw-r--r--  mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp | 34
-rw-r--r--  mlir/lib/Dialect/Traits.cpp | 16
-rw-r--r--  mlir/lib/Dialect/Transform/Transforms/CheckUses.cpp | 6
-rw-r--r--  mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp | 4
-rw-r--r--  mlir/lib/Dialect/Utils/StaticValueUtils.cpp | 12
-rw-r--r--  mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp | 18
-rw-r--r--  mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp | 16
-rw-r--r--  mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp | 2
-rw-r--r--  mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp | 40
-rw-r--r--  mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp | 10
-rw-r--r--  mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp | 2
-rw-r--r--  mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp | 20
-rw-r--r--  mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp | 6
-rw-r--r--  mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp | 54
-rw-r--r--  mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp | 14
-rw-r--r--  mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp | 31
-rw-r--r--  mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp | 42
-rw-r--r--  mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp | 6
-rw-r--r--  mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp | 26
-rw-r--r--  mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp | 6
-rw-r--r--  mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp | 8
-rw-r--r--  mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp | 4
141 files changed, 924 insertions, 953 deletions
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp
index 0c69cdc02791..d07d6518d57c 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp
@@ -65,7 +65,7 @@ static void patchOperandSegmentSizes(ArrayRef<NamedAttribute> attrs,
newAttrs.push_back(attr);
continue;
}
- auto segmentAttr = attr.getValue().cast<DenseI32ArrayAttr>();
+ auto segmentAttr = cast<DenseI32ArrayAttr>(attr.getValue());
MLIRContext *context = segmentAttr.getContext();
DenseI32ArrayAttr newSegments;
switch (action) {
@@ -128,7 +128,7 @@ LogicalResult RawBufferAtomicByCasPattern<AtomicOp, ArithOp>::matchAndRewrite(
Value prevLoadForCompare = prevLoad;
Value atomicResForCompare = atomicRes;
- if (auto floatDataTy = dataType.dyn_cast<FloatType>()) {
+ if (auto floatDataTy = dyn_cast<FloatType>(dataType)) {
Type equivInt = rewriter.getIntegerType(floatDataTy.getWidth());
prevLoadForCompare =
rewriter.create<arith::BitcastOp>(loc, equivInt, prevLoad);
diff --git a/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp b/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp
index 9b0ad3c119d0..4b3730a4aa39 100644
--- a/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp
+++ b/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp
@@ -136,7 +136,7 @@ static bool isLocallyDefined(Value v, Operation *enclosingOp) {
bool mlir::affine::isLoopMemoryParallel(AffineForOp forOp) {
// Any memref-typed iteration arguments are treated as serializing.
if (llvm::any_of(forOp.getResultTypes(),
- [](Type type) { return type.isa<BaseMemRefType>(); }))
+ [](Type type) { return isa<BaseMemRefType>(type); }))
return false;
// Collect all load and store ops in loop nest rooted at 'forOp'.
diff --git a/mlir/lib/Dialect/Affine/Analysis/LoopAnalysis.cpp b/mlir/lib/Dialect/Affine/Analysis/LoopAnalysis.cpp
index 9db1e998bb16..c97e99c0a0c1 100644
--- a/mlir/lib/Dialect/Affine/Analysis/LoopAnalysis.cpp
+++ b/mlir/lib/Dialect/Affine/Analysis/LoopAnalysis.cpp
@@ -162,7 +162,7 @@ uint64_t mlir::affine::getLargestDivisorOfTripCount(AffineForOp forOp) {
/// conservative.
static bool isAccessIndexInvariant(Value iv, Value index) {
assert(isAffineForInductionVar(iv) && "iv must be a AffineForOp");
- assert(index.getType().isa<IndexType>() && "index must be of IndexType");
+ assert(isa<IndexType>(index.getType()) && "index must be of IndexType");
SmallVector<Operation *, 4> affineApplyOps;
getReachableAffineApplyOps({index}, affineApplyOps);
@@ -262,7 +262,7 @@ static bool isContiguousAccess(Value iv, LoadOrStoreOp memoryOp,
template <typename LoadOrStoreOp>
static bool isVectorElement(LoadOrStoreOp memoryOp) {
auto memRefType = memoryOp.getMemRefType();
- return memRefType.getElementType().template isa<VectorType>();
+ return isa<VectorType>(memRefType.getElementType());
}
using VectorizableOpFun = std::function<bool(AffineForOp, Operation &)>;
diff --git a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
index 438296f25096..4433d94eb145 100644
--- a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
@@ -190,7 +190,7 @@ void MemRefDependenceGraph::addEdge(unsigned srcId, unsigned dstId,
if (!hasEdge(srcId, dstId, value)) {
outEdges[srcId].push_back({dstId, value});
inEdges[dstId].push_back({srcId, value});
- if (value.getType().isa<MemRefType>())
+ if (isa<MemRefType>(value.getType()))
memrefEdgeCount[value]++;
}
}
@@ -200,7 +200,7 @@ void MemRefDependenceGraph::removeEdge(unsigned srcId, unsigned dstId,
Value value) {
assert(inEdges.count(dstId) > 0);
assert(outEdges.count(srcId) > 0);
- if (value.getType().isa<MemRefType>()) {
+ if (isa<MemRefType>(value.getType())) {
assert(memrefEdgeCount.count(value) > 0);
memrefEdgeCount[value]--;
}
@@ -289,7 +289,7 @@ void MemRefDependenceGraph::gatherDefiningNodes(
// By definition of edge, if the edge value is a non-memref value,
// then the dependence is between a graph node which defines an SSA value
// and another graph node which uses the SSA value.
- if (!edge.value.getType().isa<MemRefType>())
+ if (!isa<MemRefType>(edge.value.getType()))
definingNodes.insert(edge.id);
}
@@ -473,7 +473,7 @@ void MemRefDependenceGraph::forEachMemRefEdge(
ArrayRef<Edge> edges, const std::function<void(Edge)> &callback) {
for (const auto &edge : edges) {
// Skip if 'edge' is not a memref dependence edge.
- if (!edge.value.getType().isa<MemRefType>())
+ if (!isa<MemRefType>(edge.value.getType()))
continue;
assert(nodes.count(edge.id) > 0);
// Skip if 'edge.id' is not a loop nest.
@@ -808,13 +808,13 @@ std::optional<bool> ComputationSliceState::isMaximal() const {
}
unsigned MemRefRegion::getRank() const {
- return memref.getType().cast<MemRefType>().getRank();
+ return cast<MemRefType>(memref.getType()).getRank();
}
std::optional<int64_t> MemRefRegion::getConstantBoundingSizeAndShape(
SmallVectorImpl<int64_t> *shape, std::vector<SmallVector<int64_t, 4>> *lbs,
SmallVectorImpl<int64_t> *lbDivisors) const {
- auto memRefType = memref.getType().cast<MemRefType>();
+ auto memRefType = cast<MemRefType>(memref.getType());
unsigned rank = memRefType.getRank();
if (shape)
shape->reserve(rank);
@@ -875,7 +875,7 @@ std::optional<int64_t> MemRefRegion::getConstantBoundingSizeAndShape(
void MemRefRegion::getLowerAndUpperBound(unsigned pos, AffineMap &lbMap,
AffineMap &ubMap) const {
assert(pos < cst.getNumDimVars() && "invalid position");
- auto memRefType = memref.getType().cast<MemRefType>();
+ auto memRefType = cast<MemRefType>(memref.getType());
unsigned rank = memRefType.getRank();
assert(rank == cst.getNumDimVars() && "inconsistent memref region");
@@ -1049,7 +1049,7 @@ LogicalResult MemRefRegion::compute(Operation *op, unsigned loopDepth,
// to guard against potential over-approximation from projection.
// TODO: Support dynamic memref dimensions.
if (addMemRefDimBounds) {
- auto memRefType = memref.getType().cast<MemRefType>();
+ auto memRefType = cast<MemRefType>(memref.getType());
for (unsigned r = 0; r < rank; r++) {
cst.addBound(BoundType::LB, /*pos=*/r, /*value=*/0);
if (memRefType.isDynamicDim(r))
@@ -1071,7 +1071,7 @@ mlir::affine::getMemRefIntOrFloatEltSizeInBytes(MemRefType memRefType) {
unsigned sizeInBits;
if (elementType.isIntOrFloat()) {
sizeInBits = elementType.getIntOrFloatBitWidth();
- } else if (auto vectorType = elementType.dyn_cast<VectorType>()) {
+ } else if (auto vectorType = dyn_cast<VectorType>(elementType)) {
if (vectorType.getElementType().isIntOrFloat())
sizeInBits =
vectorType.getElementTypeBitWidth() * vectorType.getNumElements();
@@ -1085,7 +1085,7 @@ mlir::affine::getMemRefIntOrFloatEltSizeInBytes(MemRefType memRefType) {
// Returns the size of the region.
std::optional<int64_t> MemRefRegion::getRegionSize() {
- auto memRefType = memref.getType().cast<MemRefType>();
+ auto memRefType = cast<MemRefType>(memref.getType());
if (!memRefType.getLayout().isIdentity()) {
LLVM_DEBUG(llvm::dbgs() << "Non-identity layout map not yet supported\n");
@@ -1119,7 +1119,7 @@ mlir::affine::getIntOrFloatMemRefSizeInBytes(MemRefType memRefType) {
if (!memRefType.hasStaticShape())
return std::nullopt;
auto elementType = memRefType.getElementType();
- if (!elementType.isIntOrFloat() && !elementType.isa<VectorType>())
+ if (!elementType.isIntOrFloat() && !isa<VectorType>(elementType))
return std::nullopt;
auto sizeInBytes = getMemRefIntOrFloatEltSizeInBytes(memRefType);
@@ -1708,7 +1708,7 @@ MemRefAccess::MemRefAccess(Operation *loadOrStoreOpInst) {
}
unsigned MemRefAccess::getRank() const {
- return memref.getType().cast<MemRefType>().getRank();
+ return cast<MemRefType>(memref.getType()).getRank();
}
bool MemRefAccess::isStore() const {
diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
index 89f0a9e92279..2a9416f39f2f 100644
--- a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
@@ -289,7 +289,7 @@ bool MemRefDependenceGraph::init() {
// memref type. Call Op that returns one or more memref type results
// is already taken care of, by the previous conditions.
if (llvm::any_of(op.getOperandTypes(),
- [&](Type t) { return t.isa<MemRefType>(); })) {
+ [&](Type t) { return isa<MemRefType>(t); })) {
Node node(nextNodeId++, &op);
nodes.insert({node.id, node});
}
@@ -379,7 +379,7 @@ static Value createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst,
OpBuilder top(forInst->getParentRegion());
// Create new memref type based on slice bounds.
auto oldMemRef = cast<AffineWriteOpInterface>(srcStoreOpInst).getMemRef();
- auto oldMemRefType = oldMemRef.getType().cast<MemRefType>();
+ auto oldMemRefType = cast<MemRefType>(oldMemRef.getType());
unsigned rank = oldMemRefType.getRank();
// Compute MemRefRegion for 'srcStoreOpInst' at depth 'dstLoopDepth'.
@@ -516,7 +516,7 @@ static bool hasNonAffineUsersOnThePath(unsigned srcId, unsigned dstId,
return WalkResult::advance();
for (Value v : op->getOperands())
// Collect memref values only.
- if (v.getType().isa<MemRefType>())
+ if (isa<MemRefType>(v.getType()))
memRefValues.insert(v);
return WalkResult::advance();
});
diff --git a/mlir/lib/Dialect/Affine/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Dialect/Affine/Transforms/PipelineDataTransfer.cpp
index 7d815f742541..7029251a3720 100644
--- a/mlir/lib/Dialect/Affine/Transforms/PipelineDataTransfer.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/PipelineDataTransfer.cpp
@@ -88,7 +88,7 @@ static bool doubleBuffer(Value oldMemRef, AffineForOp forOp) {
return MemRefType::Builder(oldMemRefType).setShape(newShape).setLayout({});
};
- auto oldMemRefType = oldMemRef.getType().cast<MemRefType>();
+ auto oldMemRefType = cast<MemRefType>(oldMemRef.getType());
auto newMemRefType = doubleShape(oldMemRefType);
// The double buffer is allocated right before 'forOp'.
diff --git a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp
index 8987a82b7206..49618074ec22 100644
--- a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp
@@ -100,9 +100,9 @@ void SimplifyAffineStructures::runOnOperation() {
SmallVector<Operation *> opsToSimplify;
func.walk([&](Operation *op) {
for (auto attr : op->getAttrs()) {
- if (auto mapAttr = attr.getValue().dyn_cast<AffineMapAttr>())
+ if (auto mapAttr = dyn_cast<AffineMapAttr>(attr.getValue()))
simplifyAndUpdateAttribute(op, attr.getName(), mapAttr);
- else if (auto setAttr = attr.getValue().dyn_cast<IntegerSetAttr>())
+ else if (auto setAttr = dyn_cast<IntegerSetAttr>(attr.getValue()))
simplifyAndUpdateAttribute(op, attr.getName(), setAttr);
}
diff --git a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
index 1d347329c000..b23a2cce35ce 100644
--- a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
@@ -838,7 +838,7 @@ void VectorizationState::registerValueVectorReplacementImpl(Value replaced,
Value replacement) {
assert(!valueVectorReplacement.contains(replaced) &&
"Vector replacement already registered");
- assert(replacement.getType().isa<VectorType>() &&
+ assert(isa<VectorType>(replacement.getType()) &&
"Expected vector type in vector replacement");
valueVectorReplacement.map(replaced, replacement);
}
@@ -883,7 +883,7 @@ void VectorizationState::registerValueScalarReplacementImpl(Value replaced,
Value replacement) {
assert(!valueScalarReplacement.contains(replaced) &&
"Scalar value replacement already registered");
- assert(!replacement.getType().isa<VectorType>() &&
+ assert(!isa<VectorType>(replacement.getType()) &&
"Expected scalar type in scalar replacement");
valueScalarReplacement.map(replaced, replacement);
}
@@ -946,7 +946,7 @@ isVectorizableLoopPtrFactory(const DenseSet<Operation *> &parallelLoops,
/// strategy on the scalar type.
static VectorType getVectorType(Type scalarTy,
const VectorizationStrategy *strategy) {
- assert(!scalarTy.isa<VectorType>() && "Expected scalar type");
+ assert(!isa<VectorType>(scalarTy) && "Expected scalar type");
return VectorType::get(strategy->vectorSizes, scalarTy);
}
@@ -1137,7 +1137,7 @@ static Value vectorizeOperand(Value operand, VectorizationState &state) {
// An vector operand that is not in the replacement map should never reach
// this point. Reaching this point could mean that the code was already
// vectorized and we shouldn't try to vectorize already vectorized code.
- assert(!operand.getType().isa<VectorType>() &&
+ assert(!isa<VectorType>(operand.getType()) &&
"Vector op not found in replacement map");
// Vectorize constant.
diff --git a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
index 94203ec94274..01c7c77319f0 100644
--- a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
@@ -1852,7 +1852,7 @@ static void getMultiLevelStrides(const MemRefRegion &region,
int64_t numEltPerStride = 1;
int64_t stride = 1;
for (int d = bufferShape.size() - 1; d >= 1; d--) {
- int64_t dimSize = region.memref.getType().cast<MemRefType>().getDimSize(d);
+ int64_t dimSize = cast<MemRefType>(region.memref.getType()).getDimSize(d);
stride *= dimSize;
numEltPerStride *= bufferShape[d];
// A stride is needed only if the region has a shorter extent than the
@@ -1891,7 +1891,7 @@ generatePointWiseCopy(Location loc, Value memref, Value fastMemRef,
return ubMap.getNumInputs() == ubOperands.size();
}));
- unsigned rank = memref.getType().cast<MemRefType>().getRank();
+ unsigned rank = cast<MemRefType>(memref.getType()).getRank();
assert(lbMaps.size() == rank && "wrong number of lb maps");
assert(ubMaps.size() == rank && "wrong number of ub maps");
@@ -2003,7 +2003,7 @@ static LogicalResult generateCopy(
auto loc = region.loc;
auto memref = region.memref;
- auto memRefType = memref.getType().cast<MemRefType>();
+ auto memRefType = cast<MemRefType>(memref.getType());
if (!memRefType.getLayout().isIdentity()) {
LLVM_DEBUG(llvm::dbgs() << "Non-identity layout map not yet supported\n");
@@ -2276,7 +2276,7 @@ static bool getFullMemRefAsRegion(Operation *op, unsigned numParamLoopIVs,
assert(false && "expected load or store op");
return false;
}
- auto memRefType = region->memref.getType().cast<MemRefType>();
+ auto memRefType = cast<MemRefType>(region->memref.getType());
if (!memRefType.hasStaticShape())
return false;
diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
index e454567c9213..4e02b612b9bf 100644
--- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -1119,9 +1119,9 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith(
ArrayRef<Value> extraIndices, AffineMap indexRemap,
ArrayRef<Value> extraOperands, ArrayRef<Value> symbolOperands,
bool allowNonDereferencingOps) {
- unsigned newMemRefRank = newMemRef.getType().cast<MemRefType>().getRank();
+ unsigned newMemRefRank = cast<MemRefType>(newMemRef.getType()).getRank();
(void)newMemRefRank; // unused in opt mode
- unsigned oldMemRefRank = oldMemRef.getType().cast<MemRefType>().getRank();
+ unsigned oldMemRefRank = cast<MemRefType>(oldMemRef.getType()).getRank();
(void)oldMemRefRank; // unused in opt mode
if (indexRemap) {
assert(indexRemap.getNumSymbols() == symbolOperands.size() &&
@@ -1134,8 +1134,8 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith(
}
// Assert same elemental type.
- assert(oldMemRef.getType().cast<MemRefType>().getElementType() ==
- newMemRef.getType().cast<MemRefType>().getElementType());
+ assert(cast<MemRefType>(oldMemRef.getType()).getElementType() ==
+ cast<MemRefType>(newMemRef.getType()).getElementType());
SmallVector<unsigned, 2> usePositions;
for (const auto &opEntry : llvm::enumerate(op->getOperands())) {
@@ -1172,7 +1172,7 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith(
// Perform index rewrites for the dereferencing op and then replace the op
NamedAttribute oldMapAttrPair =
affMapAccInterface.getAffineMapAttrForMemRef(oldMemRef);
- AffineMap oldMap = oldMapAttrPair.getValue().cast<AffineMapAttr>().getValue();
+ AffineMap oldMap = cast<AffineMapAttr>(oldMapAttrPair.getValue()).getValue();
unsigned oldMapNumInputs = oldMap.getNumInputs();
SmallVector<Value, 4> oldMapOperands(
op->operand_begin() + memRefOperandPos + 1,
@@ -1294,9 +1294,9 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith(
ArrayRef<Value> symbolOperands, Operation *domOpFilter,
Operation *postDomOpFilter, bool allowNonDereferencingOps,
bool replaceInDeallocOp) {
- unsigned newMemRefRank = newMemRef.getType().cast<MemRefType>().getRank();
+ unsigned newMemRefRank = cast<MemRefType>(newMemRef.getType()).getRank();
(void)newMemRefRank; // unused in opt mode
- unsigned oldMemRefRank = oldMemRef.getType().cast<MemRefType>().getRank();
+ unsigned oldMemRefRank = cast<MemRefType>(oldMemRef.getType()).getRank();
(void)oldMemRefRank;
if (indexRemap) {
assert(indexRemap.getNumSymbols() == symbolOperands.size() &&
@@ -1309,8 +1309,8 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith(
}
// Assert same elemental type.
- assert(oldMemRef.getType().cast<MemRefType>().getElementType() ==
- newMemRef.getType().cast<MemRefType>().getElementType());
+ assert(cast<MemRefType>(oldMemRef.getType()).getElementType() ==
+ cast<MemRefType>(newMemRef.getType()).getElementType());
std::unique_ptr<DominanceInfo> domInfo;
std::unique_ptr<PostDominanceInfo> postDomInfo;
@@ -1734,7 +1734,7 @@ LogicalResult mlir::affine::normalizeMemRef(memref::AllocOp *allocOp) {
SmallVector<std::tuple<AffineExpr, unsigned, unsigned>> tileSizePos;
(void)getTileSizePos(layoutMap, tileSizePos);
if (newMemRefType.getNumDynamicDims() > 0 && !tileSizePos.empty()) {
- MemRefType oldMemRefType = oldMemRef.getType().cast<MemRefType>();
+ MemRefType oldMemRefType = cast<MemRefType>(oldMemRef.getType());
SmallVector<Value, 4> newDynamicSizes;
createNewDynamicSizes(oldMemRefType, newMemRefType, layoutMap, allocOp, b,
newDynamicSizes);
diff --git a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
index 9602d530cf82..85e07253c488 100644
--- a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -34,7 +34,7 @@ struct ConstantOpInterface
return constantOp->emitError("could not infer memory space");
// Only ranked tensors are supported.
- if (!constantOp.getType().isa<RankedTensorType>())
+ if (!isa<RankedTensorType>(constantOp.getType()))
return failure();
// Only constants inside a module are supported.
@@ -58,7 +58,7 @@ struct ConstantOpInterface
bool isWritable(Operation *op, Value value,
const AnalysisState &state) const {
// Memory locations returned by memref::GetGlobalOp may not be written to.
- assert(value.isa<OpResult>());
+ assert(isa<OpResult>(value));
return false;
}
};
@@ -84,21 +84,21 @@ struct IndexCastOpInterface
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
const BufferizationOptions &options) const {
auto castOp = cast<arith::IndexCastOp>(op);
- auto resultTensorType = castOp.getType().cast<TensorType>();
+ auto resultTensorType = cast<TensorType>(castOp.getType());
FailureOr<Value> source = getBuffer(rewriter, castOp.getIn(), options);
if (failed(source))
return failure();
- auto sourceType = source->getType().cast<BaseMemRefType>();
+ auto sourceType = cast<BaseMemRefType>(source->getType());
// Result type should have same layout and address space as the source type.
BaseMemRefType resultType;
- if (auto rankedMemRefType = sourceType.dyn_cast<MemRefType>()) {
+ if (auto rankedMemRefType = dyn_cast<MemRefType>(sourceType)) {
resultType = MemRefType::get(
rankedMemRefType.getShape(), resultTensorType.getElementType(),
rankedMemRefType.getLayout(), rankedMemRefType.getMemorySpace());
} else {
- auto unrankedMemrefType = sourceType.cast<UnrankedMemRefType>();
+ auto unrankedMemrefType = cast<UnrankedMemRefType>(sourceType);
resultType = UnrankedMemRefType::get(resultTensorType.getElementType(),
unrankedMemrefType.getMemorySpace());
}
diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
index 22ec425b4730..1a50b4ad5598 100644
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
@@ -63,10 +63,10 @@ static Value createScalarOrSplatConstant(ConversionPatternRewriter &rewriter,
Location loc, Type type,
const APInt &value) {
TypedAttr attr;
- if (auto intTy = type.dyn_cast<IntegerType>()) {
+ if (auto intTy = dyn_cast<IntegerType>(type)) {
attr = rewriter.getIntegerAttr(type, value);
} else {
- auto vecTy = type.cast<VectorType>();
+ auto vecTy = cast<VectorType>(type);
attr = SplatElementsAttr::get(vecTy, value);
}
@@ -78,10 +78,10 @@ static Value createScalarOrSplatConstant(ConversionPatternRewriter &rewriter,
Location loc, Type type,
int64_t value) {
unsigned elementBitWidth = 0;
- if (auto intTy = type.dyn_cast<IntegerType>())
+ if (auto intTy = dyn_cast<IntegerType>(type))
elementBitWidth = intTy.getWidth();
else
- elementBitWidth = type.cast<VectorType>().getElementTypeBitWidth();
+ elementBitWidth = cast<VectorType>(type).getElementTypeBitWidth();
return createScalarOrSplatConstant(rewriter, loc, type,
APInt(elementBitWidth, value));
@@ -95,7 +95,7 @@ static Value createScalarOrSplatConstant(ConversionPatternRewriter &rewriter,
static Value extractLastDimSlice(ConversionPatternRewriter &rewriter,
Location loc, Value input,
int64_t lastOffset) {
- ArrayRef<int64_t> shape = input.getType().cast<VectorType>().getShape();
+ ArrayRef<int64_t> shape = cast<VectorType>(input.getType()).getShape();
assert(lastOffset < shape.back() && "Offset out of bounds");
// Scalarize the result in case of 1D vectors.
@@ -125,7 +125,7 @@ extractLastDimHalves(ConversionPatternRewriter &rewriter, Location loc,
// `input` is a scalar, this is a noop.
static Value dropTrailingX1Dim(ConversionPatternRewriter &rewriter,
Location loc, Value input) {
- auto vecTy = input.getType().dyn_cast<VectorType>();
+ auto vecTy = dyn_cast<VectorType>(input.getType());
if (!vecTy)
return input;
@@ -142,7 +142,7 @@ static Value dropTrailingX1Dim(ConversionPatternRewriter &rewriter,
/// `input` is a scalar, this is a noop.
static Value appendX1Dim(ConversionPatternRewriter &rewriter, Location loc,
Value input) {
- auto vecTy = input.getType().dyn_cast<VectorType>();
+ auto vecTy = dyn_cast<VectorType>(input.getType());
if (!vecTy)
return input;
@@ -159,11 +159,11 @@ static Value appendX1Dim(ConversionPatternRewriter &rewriter, Location loc,
static Value insertLastDimSlice(ConversionPatternRewriter &rewriter,
Location loc, Value source, Value dest,
int64_t lastOffset) {
- ArrayRef<int64_t> shape = dest.getType().cast<VectorType>().getShape();
+ ArrayRef<int64_t> shape = cast<VectorType>(dest.getType()).getShape();
assert(lastOffset < shape.back() && "Offset out of bounds");
// Handle scalar source.
- if (source.getType().isa<IntegerType>())
+ if (isa<IntegerType>(source.getType()))
return rewriter.create<vector::InsertOp>(loc, source, dest, lastOffset);
SmallVector<int64_t> offsets(shape.size(), 0);
@@ -215,14 +215,14 @@ struct ConvertConstant final : OpConversionPattern<arith::ConstantOp> {
unsigned newBitWidth = newType.getElementTypeBitWidth();
Attribute oldValue = op.getValueAttr();
- if (auto intAttr = oldValue.dyn_cast<IntegerAttr>()) {
+ if (auto intAttr = dyn_cast<IntegerAttr>(oldValue)) {
auto [low, high] = getHalves(intAttr.getValue(), newBitWidth);
auto newAttr = DenseElementsAttr::get(newType, {low, high});
rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newAttr);
return success();
}
- if (auto splatAttr = oldValue.dyn_cast<SplatElementsAttr>()) {
+ if (auto splatAttr = dyn_cast<SplatElementsAttr>(oldValue)) {
auto [low, high] =
getHalves(splatAttr.getSplatValue<APInt>(), newBitWidth);
int64_t numSplatElems = splatAttr.getNumElements();
@@ -238,7 +238,7 @@ struct ConvertConstant final : OpConversionPattern<arith::ConstantOp> {
return success();
}
- if (auto elemsAttr = oldValue.dyn_cast<DenseElementsAttr>()) {
+ if (auto elemsAttr = dyn_cast<DenseElementsAttr>(oldValue)) {
int64_t numElems = elemsAttr.getNumElements();
SmallVector<APInt> values;
values.reserve(numElems * 2);
@@ -527,9 +527,8 @@ struct ConvertMaxMin final : OpConversionPattern<SourceOp> {
Location loc = op->getLoc();
Type oldTy = op.getType();
- auto newTy = this->getTypeConverter()
- ->convertType(oldTy)
- .template dyn_cast_or_null<VectorType>();
+ auto newTy = dyn_cast_or_null<VectorType>(
+ this->getTypeConverter()->convertType(oldTy));
if (!newTy)
return rewriter.notifyMatchFailure(
loc, llvm::formatv("unsupported type: {0}", op.getType()));
@@ -549,11 +548,11 @@ struct ConvertMaxMin final : OpConversionPattern<SourceOp> {
/// Returns true iff the type is `index` or `vector<...index>`.
static bool isIndexOrIndexVector(Type type) {
- if (type.isa<IndexType>())
+ if (isa<IndexType>(type))
return true;
- if (auto vectorTy = type.dyn_cast<VectorType>())
- if (vectorTy.getElementType().isa<IndexType>())
+ if (auto vectorTy = dyn_cast<VectorType>(type))
+ if (isa<IndexType>(vectorTy.getElementType()))
return true;
return false;
@@ -610,7 +609,7 @@ struct ConvertIndexCastIndexToInt final : OpConversionPattern<CastOp> {
// Emit an index cast over the matching narrow type.
Type narrowTy =
rewriter.getIntegerType(typeConverter->getMaxTargetIntBitWidth());
- if (auto vecTy = resultType.dyn_cast<VectorType>())
+ if (auto vecTy = dyn_cast<VectorType>(resultType))
narrowTy = VectorType::get(vecTy.getShape(), narrowTy);
// Sign or zero-extend the result. Let the matching conversion pattern
@@ -1116,7 +1115,7 @@ arith::WideIntEmulationConverter::WideIntEmulationConverter(
// Vector case.
addConversion([this](VectorType ty) -> std::optional<Type> {
- auto intTy = ty.getElementType().dyn_cast<IntegerType>();
+ auto intTy = dyn_cast<IntegerType>(ty.getElementType());
if (!intTy)
return ty;
diff --git a/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp b/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp
index 787d4989bbab..8eddd811dbea 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp
@@ -86,12 +86,12 @@ reifyValueBound(OpBuilder &b, Location loc, presburger::BoundType type,
continue;
}
- assert(value.getType().cast<ShapedType>().isDynamicDim(*dim) &&
+ assert(cast<ShapedType>(value.getType()).isDynamicDim(*dim) &&
"expected dynamic dim");
- if (value.getType().isa<RankedTensorType>()) {
+ if (isa<RankedTensorType>(value.getType())) {
// A tensor dimension is used: generate a tensor.dim.
operands.push_back(b.create<tensor::DimOp>(loc, value, *dim));
- } else if (value.getType().isa<MemRefType>()) {
+ } else if (isa<MemRefType>(value.getType())) {
// A memref dimension is used: generate a memref.dim.
operands.push_back(b.create<memref::DimOp>(loc, value, *dim));
} else {
diff --git a/mlir/lib/Dialect/Arith/Utils/Utils.cpp b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
index 45a4bf74b915..fb363c82a069 100644
--- a/mlir/lib/Dialect/Arith/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
@@ -58,7 +58,7 @@ Value mlir::getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc,
OpFoldResult ofr) {
if (auto value = ofr.dyn_cast<Value>())
return value;
- auto attr = ofr.dyn_cast<Attribute>().dyn_cast<IntegerAttr>();
+ auto attr = dyn_cast<IntegerAttr>(ofr.dyn_cast<Attribute>());
assert(attr && "expect the op fold result casts to an integer attribute");
return b.create<arith::ConstantIndexOp>(loc, attr.getValue().getSExtValue());
}
@@ -73,8 +73,8 @@ Value mlir::getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc,
if (targetIsIndex ^ valueIsIndex)
return b.create<arith::IndexCastOp>(loc, targetType, value);
- auto targetIntegerType = targetType.dyn_cast<IntegerType>();
- auto valueIntegerType = value.getType().dyn_cast<IntegerType>();
+ auto targetIntegerType = dyn_cast<IntegerType>(targetType);
+ auto valueIntegerType = dyn_cast<IntegerType>(value.getType());
assert(targetIntegerType && valueIntegerType &&
"unexpected cast between types other than integers and index");
assert(targetIntegerType.getSignedness() == valueIntegerType.getSignedness());
@@ -88,9 +88,9 @@ Value mlir::convertScalarToDtype(OpBuilder &b, Location loc, Value operand,
Type toType, bool isUnsignedCast) {
if (operand.getType() == toType)
return operand;
- if (auto toIntType = toType.dyn_cast<IntegerType>()) {
+ if (auto toIntType = dyn_cast<IntegerType>(toType)) {
// If operand is floating point, cast directly to the int type.
- if (operand.getType().isa<FloatType>()) {
+ if (isa<FloatType>(operand.getType())) {
if (isUnsignedCast)
return b.create<arith::FPToUIOp>(loc, toType, operand);
return b.create<arith::FPToSIOp>(loc, toType, operand);
@@ -98,7 +98,7 @@ Value mlir::convertScalarToDtype(OpBuilder &b, Location loc, Value operand,
// Cast index operands directly to the int type.
if (operand.getType().isIndex())
return b.create<arith::IndexCastOp>(loc, toType, operand);
- if (auto fromIntType = operand.getType().dyn_cast<IntegerType>()) {
+ if (auto fromIntType = dyn_cast<IntegerType>(operand.getType())) {
// Either extend or truncate.
if (toIntType.getWidth() > fromIntType.getWidth()) {
if (isUnsignedCast)
@@ -108,15 +108,15 @@ Value mlir::convertScalarToDtype(OpBuilder &b, Location loc, Value operand,
if (toIntType.getWidth() < fromIntType.getWidth())
return b.create<arith::TruncIOp>(loc, toType, operand);
}
- } else if (auto toFloatType = toType.dyn_cast<FloatType>()) {
+ } else if (auto toFloatType = dyn_cast<FloatType>(toType)) {
// If operand is integer, cast directly to the float type.
// Note that it is unclear how to cast from BF16<->FP16.
- if (operand.getType().isa<IntegerType>()) {
+ if (isa<IntegerType>(operand.getType())) {
if (isUnsignedCast)
return b.create<arith::UIToFPOp>(loc, toFloatType, operand);
return b.create<arith::SIToFPOp>(loc, toFloatType, operand);
}
- if (auto fromFloatType = operand.getType().dyn_cast<FloatType>()) {
+ if (auto fromFloatType = dyn_cast<FloatType>(operand.getType())) {
if (toFloatType.getWidth() > fromFloatType.getWidth())
return b.create<arith::ExtFOp>(loc, toFloatType, operand);
if (toFloatType.getWidth() < fromFloatType.getWidth())
@@ -141,27 +141,27 @@ Value ArithBuilder::_and(Value lhs, Value rhs) {
return b.create<arith::AndIOp>(loc, lhs, rhs);
}
Value ArithBuilder::add(Value lhs, Value rhs) {
- if (lhs.getType().isa<FloatType>())
+ if (isa<FloatType>(lhs.getType()))
return b.create<arith::AddFOp>(loc, lhs, rhs);
return b.create<arith::AddIOp>(loc, lhs, rhs);
}
Value ArithBuilder::sub(Value lhs, Value rhs) {
- if (lhs.getType().isa<FloatType>())
+ if (isa<FloatType>(lhs.getType()))
return b.create<arith::SubFOp>(loc, lhs, rhs);
return b.create<arith::SubIOp>(loc, lhs, rhs);
}
Value ArithBuilder::mul(Value lhs, Value rhs) {
- if (lhs.getType().isa<FloatType>())
+ if (isa<FloatType>(lhs.getType()))
return b.create<arith::MulFOp>(loc, lhs, rhs);
return b.create<arith::MulIOp>(loc, lhs, rhs);
}
Value ArithBuilder::sgt(Value lhs, Value rhs) {
- if (lhs.getType().isa<FloatType>())
+ if (isa<FloatType>(lhs.getType()))
return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGT, lhs, rhs);
return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, lhs, rhs);
}
Value ArithBuilder::slt(Value lhs, Value rhs) {
- if (lhs.getType().isa<FloatType>())
+ if (isa<FloatType>(lhs.getType()))
return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OLT, lhs, rhs);
return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, lhs, rhs);
}
diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp
index 7db078ad3f0a..04f131ec51cb 100644
--- a/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp
@@ -528,9 +528,9 @@ void AsyncRuntimePolicyBasedRefCountingPass::initializeDefaultPolicy() {
Operation *op = operand.getOwner();
Type type = operand.get().getType();
- bool isToken = type.isa<TokenType>();
- bool isGroup = type.isa<GroupType>();
- bool isValue = type.isa<ValueType>();
+ bool isToken = isa<TokenType>(type);
+ bool isGroup = isa<GroupType>(type);
+ bool isValue = isa<ValueType>(type);
// Drop reference after async token or group error check (coro await).
if (auto await = dyn_cast<RuntimeIsErrorOp>(op))
diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
index 25cb61857a10..db7550d7d99f 100644
--- a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
@@ -161,7 +161,7 @@ static CoroMachinery setupCoroMachinery(func::FuncOp func) {
// We treat TokenType as state update marker to represent side-effects of
// async computations
- bool isStateful = func.getCallableResults().front().isa<TokenType>();
+ bool isStateful = isa<TokenType>(func.getCallableResults().front());
std::optional<Value> retToken;
if (isStateful)
@@ -535,7 +535,7 @@ public:
ConversionPatternRewriter &rewriter) const override {
// We can only await on one the `AwaitableType` (for `await` it can be
// a `token` or a `value`, for `await_all` it must be a `group`).
- if (!op.getOperand().getType().template isa<AwaitableType>())
+ if (!isa<AwaitableType>(op.getOperand().getType()))
return rewriter.notifyMatchFailure(op, "unsupported awaitable type");
// Check if await operation is inside the coroutine function.
@@ -646,7 +646,7 @@ public:
getReplacementValue(AwaitOp op, Value operand,
ConversionPatternRewriter &rewriter) const override {
// Load from the async value storage.
- auto valueType = operand.getType().cast<ValueType>().getValueType();
+ auto valueType = cast<ValueType>(operand.getType()).getValueType();
return rewriter.create<RuntimeLoadOp>(op->getLoc(), valueType, operand);
}
};
diff --git a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
index ed95a62b9b6f..5e36b55cff84 100644
--- a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
@@ -59,7 +59,7 @@ transform::OneShotBufferizeOp::apply(TransformResults &transformResults,
// This transform op is currently restricted to ModuleOps and function ops.
// Such ops are modified in-place.
- transformResults.set(getTransformed().cast<OpResult>(), payloadOps);
+ transformResults.set(cast<OpResult>(getTransformed()), payloadOps);
return DiagnosedSilenceableFailure::success();
}
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp
index cf51aa58a93a..b813b2425bdd 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp
@@ -280,7 +280,7 @@ private:
// defined in a non-dominated block or it is defined in the same block
// but the current value is not dominated by the source value.
if (!dominators.dominates(definingBlock, parentBlock) ||
- (definingBlock == parentBlock && value.isa<BlockArgument>())) {
+ (definingBlock == parentBlock && isa<BlockArgument>(value))) {
toProcess.emplace_back(value, parentBlock);
valuesToFree.insert(value);
} else if (visitedValues.insert(std::make_tuple(value, definingBlock))
@@ -307,8 +307,8 @@ private:
// Add new allocs and additional clone operations.
for (Value value : valuesToFree) {
- if (failed(value.isa<BlockArgument>()
- ? introduceBlockArgCopy(value.cast<BlockArgument>())
+ if (failed(isa<BlockArgument>(value)
+ ? introduceBlockArgCopy(cast<BlockArgument>(value))
: introduceValueCopyForRegionResult(value)))
return failure();
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferOptimizations.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferOptimizations.cpp
index 83b2ef6a6dac..278664abcf49 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferOptimizations.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferOptimizations.cpp
@@ -43,7 +43,7 @@ static bool isKnownControlFlowInterface(Operation *op) {
/// exceed the stack space.
static bool defaultIsSmallAlloc(Value alloc, unsigned maximumSizeInBytes,
unsigned maxRankOfAllocatedMemRef) {
- auto type = alloc.getType().dyn_cast<ShapedType>();
+ auto type = dyn_cast<ShapedType>(alloc.getType());
if (!type || !alloc.getDefiningOp<memref::AllocOp>())
return false;
if (!type.hasStaticShape()) {
@@ -355,7 +355,7 @@ public:
OpBuilder builder(startOperation);
Operation *allocOp = alloc.getDefiningOp();
Operation *alloca = builder.create<memref::AllocaOp>(
- alloc.getLoc(), alloc.getType().cast<MemRefType>(),
+ alloc.getLoc(), cast<MemRefType>(alloc.getType()),
allocOp->getOperands(), allocOp->getAttrs());
// Replace the original alloc by a newly created alloca.
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
index 7b63335d43fb..dd359c2dcca5 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
@@ -53,7 +53,7 @@ updateFuncOp(func::FuncOp func,
SmallVector<Type, 6> erasedResultTypes;
BitVector erasedResultIndices(functionType.getNumResults());
for (const auto &resultType : llvm::enumerate(functionType.getResults())) {
- if (auto memrefType = resultType.value().dyn_cast<MemRefType>()) {
+ if (auto memrefType = dyn_cast<MemRefType>(resultType.value())) {
if (!hasStaticIdentityLayout(memrefType) &&
!hasFullyDynamicLayoutMap(memrefType)) {
// Only buffers with static identity layout can be allocated. These can
@@ -103,7 +103,7 @@ static void updateReturnOps(func::FuncOp func,
SmallVector<Value, 6> copyIntoOutParams;
SmallVector<Value, 6> keepAsReturnOperands;
for (Value operand : op.getOperands()) {
- if (operand.getType().isa<MemRefType>())
+ if (isa<MemRefType>(operand.getType()))
copyIntoOutParams.push_back(operand);
else
keepAsReturnOperands.push_back(operand);
@@ -137,7 +137,7 @@ updateCalls(ModuleOp module,
SmallVector<Value, 6> replaceWithNewCallResults;
SmallVector<Value, 6> replaceWithOutParams;
for (OpResult result : op.getResults()) {
- if (result.getType().isa<MemRefType>())
+ if (isa<MemRefType>(result.getType()))
replaceWithOutParams.push_back(result);
else
replaceWithNewCallResults.push_back(result);
@@ -145,13 +145,13 @@ updateCalls(ModuleOp module,
SmallVector<Value, 6> outParams;
OpBuilder builder(op);
for (Value memref : replaceWithOutParams) {
- if (!memref.getType().cast<MemRefType>().hasStaticShape()) {
+ if (!cast<MemRefType>(memref.getType()).hasStaticShape()) {
op.emitError()
<< "cannot create out param for dynamically shaped result";
didFail = true;
return;
}
- auto memrefType = memref.getType().cast<MemRefType>();
+ auto memrefType = cast<MemRefType>(memref.getType());
auto allocType =
MemRefType::get(memrefType.getShape(), memrefType.getElementType(),
AffineMap(), memrefType.getMemorySpace());
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
index b9776e2fb209..f8231cac778a 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
@@ -68,7 +68,7 @@ void BufferPlacementAllocs::build(Operation *op) {
[=](MemoryEffects::EffectInstance &it) {
Value value = it.getValue();
return isa<MemoryEffects::Allocate>(it.getEffect()) && value &&
- value.isa<OpResult>() &&
+ isa<OpResult>(value) &&
it.getResource() !=
SideEffects::AutomaticAllocationScopeResource::get();
});
@@ -149,7 +149,7 @@ bool BufferPlacementTransformationBase::isLoop(Operation *op) {
FailureOr<memref::GlobalOp>
bufferization::getGlobalFor(arith::ConstantOp constantOp, uint64_t alignment,
Attribute memorySpace) {
- auto type = constantOp.getType().cast<RankedTensorType>();
+ auto type = cast<RankedTensorType>(constantOp.getType());
auto moduleOp = constantOp->getParentOfType<ModuleOp>();
if (!moduleOp)
return failure();
@@ -185,14 +185,14 @@ bufferization::getGlobalFor(arith::ConstantOp constantOp, uint64_t alignment,
: IntegerAttr();
BufferizeTypeConverter typeConverter;
- auto memrefType = typeConverter.convertType(type).cast<MemRefType>();
+ auto memrefType = cast<MemRefType>(typeConverter.convertType(type));
if (memorySpace)
memrefType = MemRefType::Builder(memrefType).setMemorySpace(memorySpace);
auto global = globalBuilder.create<memref::GlobalOp>(
constantOp.getLoc(), (Twine("__constant_") + os.str()).str(),
/*sym_visibility=*/globalBuilder.getStringAttr("private"),
/*type=*/memrefType,
- /*initial_value=*/constantOp.getValue().cast<ElementsAttr>(),
+ /*initial_value=*/cast<ElementsAttr>(constantOp.getValue()),
/*constant=*/true,
/*alignment=*/memrefAlignment);
symbolTable.insert(global);
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index 4eabfccf2514..24aaff0e4882 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -44,7 +44,7 @@ using namespace mlir::bufferization;
static Value materializeToTensor(OpBuilder &builder, TensorType type,
ValueRange inputs, Location loc) {
assert(inputs.size() == 1);
- assert(inputs[0].getType().isa<BaseMemRefType>());
+ assert(isa<BaseMemRefType>(inputs[0].getType()));
return builder.create<bufferization::ToTensorOp>(loc, type, inputs[0]);
}
@@ -66,11 +66,11 @@ BufferizeTypeConverter::BufferizeTypeConverter() {
ValueRange inputs, Location loc) -> Value {
assert(inputs.size() == 1 && "expected exactly one input");
- if (auto inputType = inputs[0].getType().dyn_cast<MemRefType>()) {
+ if (auto inputType = dyn_cast<MemRefType>(inputs[0].getType())) {
// MemRef to MemRef cast.
assert(inputType != type && "expected different types");
// Unranked to ranked and ranked to unranked casts must be explicit.
- auto rankedDestType = type.dyn_cast<MemRefType>();
+ auto rankedDestType = dyn_cast<MemRefType>(type);
if (!rankedDestType)
return nullptr;
FailureOr<Value> replacement =
@@ -80,7 +80,7 @@ BufferizeTypeConverter::BufferizeTypeConverter() {
return *replacement;
}
- if (inputs[0].getType().isa<TensorType>()) {
+ if (isa<TensorType>(inputs[0].getType())) {
// Tensor to MemRef cast.
return builder.create<bufferization::ToMemrefOp>(loc, type, inputs[0]);
}
@@ -222,7 +222,7 @@ struct OneShotBufferizePass
parseLayoutMapOption(unknownTypeConversion);
opt.unknownTypeConverterFn = [=](Value value, Attribute memorySpace,
const BufferizationOptions &options) {
- auto tensorType = value.getType().cast<TensorType>();
+ auto tensorType = cast<TensorType>(value.getType());
if (unknownTypeConversionOption == LayoutMapOption::IdentityLayoutMap)
return bufferization::getMemRefTypeWithStaticIdentityLayout(
tensorType, memorySpace);
@@ -325,7 +325,7 @@ mlir::bufferization::createFinalizingBufferizePass() {
// BufferizableOpInterface-based Bufferization
//===----------------------------------------------------------------------===//
-static bool isaTensor(Type t) { return t.isa<TensorType>(); }
+static bool isaTensor(Type t) { return isa<TensorType>(t); }
/// Return true if the given op has a tensor result or a tensor operand.
static bool hasTensorSemantics(Operation *op) {
@@ -549,7 +549,7 @@ BufferizationOptions bufferization::getPartialBufferizationOptions() {
options.unknownTypeConverterFn = [](Value value, Attribute memorySpace,
const BufferizationOptions &options) {
return getMemRefTypeWithStaticIdentityLayout(
- value.getType().cast<TensorType>(), memorySpace);
+ cast<TensorType>(value.getType()), memorySpace);
};
options.opFilter.allowDialect<BufferizationDialect>();
return options;
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
index 5fc12573912f..58475d225ce8 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
@@ -33,12 +33,12 @@ neededValuesDominateInsertionPoint(const DominanceInfo &domInfo,
Operation *insertionPoint,
const SmallVector<Value> &neededValues) {
for (Value val : neededValues) {
- if (auto bbArg = val.dyn_cast<BlockArgument>()) {
+ if (auto bbArg = dyn_cast<BlockArgument>(val)) {
Block *owner = bbArg.getOwner();
if (!owner->findAncestorOpInBlock(*insertionPoint))
return false;
} else {
- auto opResult = val.cast<OpResult>();
+ auto opResult = cast<OpResult>(val);
if (!domInfo.dominates(opResult.getOwner(), insertionPoint))
return false;
}
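This hunk relies on `Value` having exactly two concrete kinds, so a failed `dyn_cast<BlockArgument>` justifies the unconditional `cast<OpResult>`. A sketch of the same dispatch as a free-standing helper (the name is illustrative):
```
#include "mlir/IR/Block.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"

using namespace mlir;

// Every Value is either a BlockArgument or an OpResult, so once the
// dyn_cast<> fails, the cast<> in the fallthrough cannot assert.
Operation *getDefiningOrParentOp(Value val) {
  if (auto bbArg = dyn_cast<BlockArgument>(val))
    return bbArg.getOwner()->getParentOp(); // may be null for detached blocks
  return cast<OpResult>(val).getOwner();
}
```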
@@ -75,7 +75,7 @@ findValidInsertionPoint(Operation *emptyTensorOp,
// * in case of an OpResult: There must be at least one op right after the
// defining op (the anchor op or one of its
// parents).
- if (auto bbArg = val.dyn_cast<BlockArgument>()) {
+ if (auto bbArg = dyn_cast<BlockArgument>(val)) {
insertionPointCandidates.push_back(
&bbArg.getOwner()->getOperations().front());
} else {
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index bf14e466190b..f73efc120d37 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -60,7 +60,7 @@ static BaseMemRefType
getBufferizedFunctionArgType(FuncOp funcOp, int64_t index,
const BufferizationOptions &options) {
auto tensorType =
- funcOp.getFunctionType().getInput(index).dyn_cast<TensorType>();
+ dyn_cast<TensorType>(funcOp.getFunctionType().getInput(index));
assert(tensorType && "expected TensorType");
BaseMemRefType memrefType = options.functionArgTypeConverterFn(
@@ -71,7 +71,7 @@ getBufferizedFunctionArgType(FuncOp funcOp, int64_t index,
if (!layoutAttr)
return memrefType;
- auto rankedMemrefType = memrefType.dyn_cast<MemRefType>();
+ auto rankedMemrefType = dyn_cast<MemRefType>(memrefType);
assert(rankedMemrefType && "buffer layout not supported on unranked tensors");
return MemRefType::get(
rankedMemrefType.getShape(), rankedMemrefType.getElementType(),
@@ -224,7 +224,7 @@ struct CallOpInterface
for (const auto &it : llvm::enumerate(callOp.getResultTypes())) {
unsigned returnValIdx = it.index();
Type returnType = it.value();
- if (!returnType.isa<TensorType>()) {
+ if (!isa<TensorType>(returnType)) {
// Non-tensor values are returned.
retValMapping[returnValIdx] = resultTypes.size();
resultTypes.push_back(returnType);
@@ -242,7 +242,7 @@ struct CallOpInterface
Value tensorOperand = opOperand.get();
// Non-tensor operands are just copied.
- if (!tensorOperand.getType().isa<TensorType>()) {
+ if (!isa<TensorType>(tensorOperand.getType())) {
newOperands[idx] = tensorOperand;
continue;
}
@@ -342,7 +342,7 @@ struct FuncOpInterface
SmallVector<Type> argTypes;
for (const auto &it : llvm::enumerate(funcType.getInputs())) {
Type argType = it.value();
- if (auto tensorType = argType.dyn_cast<TensorType>()) {
+ if (auto tensorType = dyn_cast<TensorType>(argType)) {
argTypes.push_back(
getBufferizedFunctionArgType(funcOp, it.index(), options));
continue;
@@ -356,7 +356,7 @@ struct FuncOpInterface
if (funcOp.getBody().empty()) {
SmallVector<Type> retTypes;
for (Type resultType : funcType.getResults()) {
- if (resultType.isa<TensorType>())
+ if (isa<TensorType>(resultType))
return funcOp->emitError() << "cannot bufferize bodiless function "
<< "that returns a tensor";
retTypes.push_back(resultType);
@@ -373,7 +373,7 @@ struct FuncOpInterface
// 1. Rewrite the bbArgs. Turn every tensor bbArg into a memref bbArg.
Block &frontBlock = funcOp.getBody().front();
for (BlockArgument &bbArg : frontBlock.getArguments()) {
- auto tensorType = bbArg.getType().dyn_cast<TensorType>();
+ auto tensorType = dyn_cast<TensorType>(bbArg.getType());
// Non-tensor types stay the same.
if (!tensorType)
continue;
@@ -404,7 +404,7 @@ struct FuncOpInterface
SmallVector<Value> returnValues;
for (OpOperand &returnOperand : returnOp->getOpOperands()) {
Value returnVal = returnOperand.get();
- auto tensorType = returnVal.getType().dyn_cast<TensorType>();
+ auto tensorType = dyn_cast<TensorType>(returnVal.getType());
rewriter.setInsertionPoint(returnOp);
// If not a tensor type, just forward it.
@@ -436,7 +436,7 @@ struct FuncOpInterface
bool isWritable(Operation *op, Value value,
const AnalysisState &state) const {
auto funcOp = cast<FuncOp>(op);
- BlockArgument bbArg = value.dyn_cast<BlockArgument>();
+ BlockArgument bbArg = dyn_cast<BlockArgument>(value);
assert(bbArg && "expected BlockArgument");
// "bufferization.writable" overrides other writability decisions. This is
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
index db7d4533cd9a..6da512699cc7 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
@@ -66,7 +66,7 @@ MLIR_DEFINE_EXPLICIT_TYPE_ID(mlir::bufferization::OneShotAnalysisState)
using namespace mlir;
using namespace mlir::bufferization;
-static bool isaTensor(Type t) { return t.isa<TensorType>(); }
+static bool isaTensor(Type t) { return isa<TensorType>(t); }
//===----------------------------------------------------------------------===//
// Bufferization-specific attribute manipulation.
@@ -85,11 +85,11 @@ static void setInPlaceOpOperand(OpOperand &opOperand, bool inPlace) {
SmallVector<StringRef> inPlaceVector;
if (auto attr = op->getAttr(kInPlaceOperandsAttrName)) {
inPlaceVector = SmallVector<StringRef>(llvm::to_vector<4>(
- attr.cast<ArrayAttr>().getAsValueRange<StringAttr>()));
+ cast<ArrayAttr>(attr).getAsValueRange<StringAttr>()));
} else {
inPlaceVector = SmallVector<StringRef>(op->getNumOperands(), "none");
for (OpOperand &opOperand : op->getOpOperands())
- if (opOperand.get().getType().isa<TensorType>())
+ if (isa<TensorType>(opOperand.get().getType()))
inPlaceVector[opOperand.getOperandNumber()] = "false";
}
inPlaceVector[opOperand.getOperandNumber()] = inPlace ? "true" : "false";
@@ -107,12 +107,12 @@ OneShotAnalysisState::OneShotAnalysisState(
// Set up alias sets.
op->walk([&](Operation *op) {
for (Value v : op->getResults())
- if (v.getType().isa<TensorType>())
+ if (isa<TensorType>(v.getType()))
createAliasInfoEntry(v);
for (Region &r : op->getRegions())
for (Block &b : r.getBlocks())
for (auto bbArg : b.getArguments())
- if (bbArg.getType().isa<TensorType>())
+ if (isa<TensorType>(bbArg.getType()))
createAliasInfoEntry(bbArg);
});
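The walk above filters values by `isa<TensorType>`. A condensed, self-contained variant of that traversal, assuming `root` is any operation (the helper name is illustrative):
```
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Operation.h"
#include "mlir/Support/LLVM.h"

using namespace mlir;

// Count SSA values (op results and block arguments) of tensor type.
unsigned countTensorValues(Operation *root) {
  unsigned count = 0;
  root->walk([&](Operation *op) {
    for (Value v : op->getResults())
      if (isa<TensorType>(v.getType()))
        ++count;
    for (Region &r : op->getRegions())
      for (Block &b : r)
        for (BlockArgument bbArg : b.getArguments())
          if (isa<TensorType>(bbArg.getType()))
            ++count;
  });
  return count;
}
```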
@@ -121,7 +121,7 @@ OneShotAnalysisState::OneShotAnalysisState(
if (!options.isOpAllowed(bufferizableOp))
return WalkResult::skip();
for (OpOperand &opOperand : bufferizableOp->getOpOperands())
- if (opOperand.get().getType().isa<TensorType>())
+ if (isa<TensorType>(opOperand.get().getType()))
if (bufferizableOp.mustBufferizeInPlace(opOperand, *this))
bufferizeInPlace(opOperand);
return WalkResult::advance();
@@ -187,13 +187,13 @@ void OneShotAnalysisState::gatherYieldedTensors(Operation *op) {
for (OpOperand &returnValOperand : returnOp->getOpOperands()) {
Value returnVal = returnValOperand.get();
// Skip non-tensor values.
- if (!returnVal.getType().isa<TensorType>())
+ if (!isa<TensorType>(returnVal.getType()))
continue;
// Add all aliases of the returned value. But only the ones that are in
// the same block.
applyOnAliases(returnVal, [&](Value v) {
- if (auto bbArg = v.dyn_cast<BlockArgument>()) {
+ if (auto bbArg = dyn_cast<BlockArgument>(v)) {
if (bbArg.getOwner()->getParentOp() == returnOp->getParentOp())
yieldedTensors.insert(bbArg);
return;
@@ -217,7 +217,7 @@ void OneShotAnalysisState::gatherUndefinedTensorUses(Operation *op) {
// Check all tensor OpResults.
for (OpResult opResult : op->getOpResults()) {
- if (!opResult.getType().isa<TensorType>())
+ if (!isa<TensorType>(opResult.getType()))
continue;
// If there is no preceding definition, the tensor contents are
@@ -259,7 +259,7 @@ bool OneShotAnalysisState::isWritable(Value value) const {
return bufferizableOp.isWritable(value, *this);
// Query BufferizableOpInterface to see if the BlockArgument is writable.
- if (auto bbArg = value.dyn_cast<BlockArgument>())
+ if (auto bbArg = dyn_cast<BlockArgument>(value))
if (auto bufferizableOp =
getOptions().dynCastBufferizableOp(bbArg.getOwner()->getParentOp()))
return bufferizableOp.isWritable(bbArg, *this);
@@ -431,12 +431,12 @@ static void annotateConflict(OpOperand *uRead, OpOperand *uConflictingWrite,
id + "[READ: " + std::to_string(uRead->getOperandNumber()) + "]";
readingOp->setAttr(readAttr, b.getUnitAttr());
- if (auto opResult = definition.dyn_cast<OpResult>()) {
+ if (auto opResult = dyn_cast<OpResult>(definition)) {
std::string defAttr =
id + "[DEF: result " + std::to_string(opResult.getResultNumber()) + "]";
opResult.getDefiningOp()->setAttr(defAttr, b.getUnitAttr());
} else {
- auto bbArg = definition.cast<BlockArgument>();
+ auto bbArg = cast<BlockArgument>(definition);
std::string defAttr =
id + "[DEF: bbArg " + std::to_string(bbArg.getArgNumber()) + "]";
bbArg.getOwner()->getParentOp()->setAttr(defAttr, b.getUnitAttr());
@@ -581,7 +581,7 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
continue;
}
} else {
- auto bbArg = definition.cast<BlockArgument>();
+ auto bbArg = cast<BlockArgument>(definition);
Block *block = bbArg.getOwner();
if (!block->findAncestorOpInBlock(*conflictingWritingOp)) {
LLVM_DEBUG(llvm::dbgs() << " no conflict: definition is bbArg "
@@ -715,12 +715,12 @@ static void annotateNonWritableTensor(Value value) {
static int64_t counter = 0;
OpBuilder b(value.getContext());
std::string id = "W_" + std::to_string(counter++);
- if (auto opResult = value.dyn_cast<OpResult>()) {
+ if (auto opResult = dyn_cast<OpResult>(value)) {
std::string attr = id + "[NOT-WRITABLE: result " +
std::to_string(opResult.getResultNumber()) + "]";
opResult.getDefiningOp()->setAttr(attr, b.getUnitAttr());
} else {
- auto bbArg = value.cast<BlockArgument>();
+ auto bbArg = cast<BlockArgument>(value);
std::string attr = id + "[NOT-WRITABLE: bbArg " +
std::to_string(bbArg.getArgNumber()) + "]";
bbArg.getOwner()->getParentOp()->setAttr(attr, b.getUnitAttr());
@@ -812,7 +812,7 @@ LogicalResult
OneShotAnalysisState::analyzeSingleOp(Operation *op,
const DominanceInfo &domInfo) {
for (OpOperand &opOperand : op->getOpOperands())
- if (opOperand.get().getType().isa<TensorType>())
+ if (isa<TensorType>(opOperand.get().getType()))
if (failed(bufferizableInPlaceAnalysisImpl(opOperand, *this, domInfo)))
return failure();
return success();
@@ -831,7 +831,7 @@ static void equivalenceAnalysis(SmallVector<Operation *> &ops,
for (Operation *op : ops) {
if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op)) {
for (OpResult opResult : op->getOpResults()) {
- if (!opResult.getType().isa<TensorType>())
+ if (!isa<TensorType>(opResult.getType()))
continue;
AliasingOpOperandList aliases = state.getAliasingOpOperands(opResult);
if (aliases.getNumAliases() == 0)
@@ -958,7 +958,7 @@ static LogicalResult checkAliasInfoConsistency(Operation *op,
}
for (OpOperand &opOperand : op->getOpOperands()) {
- if (opOperand.get().getType().isa<TensorType>()) {
+ if (isa<TensorType>(opOperand.get().getType())) {
if (wouldCreateReadAfterWriteInterference(
opOperand, domInfo, state,
/*checkConsistencyOnly=*/true)) {
@@ -984,7 +984,7 @@ annotateOpsWithBufferizationMarkers(Operation *op,
// Add __inplace_operands_attr__.
op->walk([&](Operation *op) {
for (OpOperand &opOperand : op->getOpOperands())
- if (opOperand.get().getType().isa<TensorType>())
+ if (isa<TensorType>(opOperand.get().getType()))
setInPlaceOpOperand(opOperand, state.isInPlace(opOperand));
});
}
@@ -1031,12 +1031,12 @@ static LogicalResult assertNoAllocsReturned(Operation *op,
for (OpOperand &returnValOperand : returnOp->getOpOperands()) {
Value returnVal = returnValOperand.get();
// Skip non-tensor values.
- if (!returnVal.getType().isa<TensorType>())
+ if (!isa<TensorType>(returnVal.getType()))
continue;
bool foundEquivValue = false;
state.applyOnEquivalenceClass(returnVal, [&](Value equivVal) {
- if (auto bbArg = equivVal.dyn_cast<BlockArgument>()) {
+ if (auto bbArg = dyn_cast<BlockArgument>(equivVal)) {
Operation *definingOp = bbArg.getOwner()->getParentOp();
if (definingOp->isProperAncestor(returnOp))
foundEquivValue = true;
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
index 27b560afdbb3..d0af1c278c14 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
@@ -109,9 +109,9 @@ static void annotateEquivalentReturnBbArg(OpOperand &returnVal,
SmallVector<int64_t> equivBbArgs;
if (op->hasAttr(kEquivalentArgsAttr)) {
- auto attr = op->getAttr(kEquivalentArgsAttr).cast<ArrayAttr>();
+ auto attr = cast<ArrayAttr>(op->getAttr(kEquivalentArgsAttr));
equivBbArgs = llvm::to_vector<4>(llvm::map_range(attr, [](Attribute a) {
- return a.cast<IntegerAttr>().getValue().getSExtValue();
+ return cast<IntegerAttr>(a).getValue().getSExtValue();
}));
} else {
equivBbArgs.append(op->getNumOperands(), -1);
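The hunk above decodes an `ArrayAttr` of `IntegerAttr`s. The same conversion as a standalone sketch (the helper name is illustrative):
```
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;

// cast<> asserts that each element really is an IntegerAttr, mirroring the
// hunk above.
SmallVector<int64_t> toInt64s(ArrayAttr attr) {
  return llvm::to_vector(llvm::map_range(attr, [](Attribute a) {
    return cast<IntegerAttr>(a).getValue().getSExtValue();
  }));
}
```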
@@ -132,10 +132,10 @@ aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
// return value may alias with any tensor bbArg.
FunctionType type = funcOp.getFunctionType();
for (const auto &inputIt : llvm::enumerate(type.getInputs())) {
- if (!inputIt.value().isa<TensorType>())
+ if (!isa<TensorType>(inputIt.value()))
continue;
for (const auto &resultIt : llvm::enumerate(type.getResults())) {
- if (!resultIt.value().isa<TensorType>())
+ if (!isa<TensorType>(resultIt.value()))
continue;
int64_t returnIdx = resultIt.index();
int64_t bbArgIdx = inputIt.index();
@@ -150,9 +150,9 @@ aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
assert(returnOp && "expected func with single return op");
for (OpOperand &returnVal : returnOp->getOpOperands())
- if (returnVal.get().getType().isa<RankedTensorType>())
+ if (isa<RankedTensorType>(returnVal.get().getType()))
for (BlockArgument bbArg : funcOp.getArguments())
- if (bbArg.getType().isa<RankedTensorType>()) {
+ if (isa<RankedTensorType>(bbArg.getType())) {
int64_t returnIdx = returnVal.getOperandNumber();
int64_t bbArgIdx = bbArg.getArgNumber();
if (state.areEquivalentBufferizedValues(returnVal.get(), bbArg)) {
@@ -193,7 +193,7 @@ funcOpBbArgReadWriteAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
for (int64_t idx = 0, e = funcOp.getFunctionType().getNumInputs(); idx < e;
++idx) {
// Skip non-tensor arguments.
- if (!funcOp.getFunctionType().getInput(idx).isa<TensorType>())
+ if (!isa<TensorType>(funcOp.getFunctionType().getInput(idx)))
continue;
bool isRead;
bool isWritten;
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp b/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp
index 4cd19b4efc63..b12ea25396b2 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp
@@ -58,7 +58,7 @@ resolveUsesInRepetitiveRegions(Operation *op,
for (OpOperand &opOperand : bufferizableOp->getOpOperands()) {
Value operand = opOperand.get();
// Skip non-tensor operands.
- if (!operand.getType().isa<TensorType>())
+ if (!isa<TensorType>(operand.getType()))
continue;
// Skip operands that do not bufferize to memory writes.
if (!bufferizableOp.bufferizesToMemoryWrite(opOperand, state))
@@ -85,7 +85,7 @@ resolveUsesInRepetitiveRegions(Operation *op,
// Insert a tensor copy and replace all uses inside of repetitive regions.
rewriter.setInsertionPoint(bufferizableOp);
auto tensorCopy = rewriter.create<AllocTensorOp>(
- bufferizableOp->getLoc(), operand.getType().cast<TensorType>(),
+ bufferizableOp->getLoc(), cast<TensorType>(operand.getType()),
/*dynamicSizes=*/ValueRange(),
/*copy=*/operand, /*memory_space=*/IntegerAttr());
for (OpOperand *use : usesInsideRegion)
@@ -137,7 +137,7 @@ mlir::bufferization::insertTensorCopies(Operation *op,
SmallVector<bool> escapeAttrValue;
bool foundTensorResult = false;
for (OpResult opResult : op->getOpResults()) {
- if (!opResult.getType().isa<TensorType>() ||
+ if (!isa<TensorType>(opResult.getType()) ||
!bufferizableOp.bufferizesToAllocation(opResult)) {
escapeAttrValue.push_back(false);
continue;
diff --git a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
index 39c9e5e1725a..8cd2ccfcf882 100644
--- a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
+++ b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
@@ -257,19 +257,19 @@ checkMappingAttributeTypes(std::optional<TransformOpInterface> transformOp,
bool hasBlockMapping =
llvm::any_of(forallOp.getMapping().value(), [](Attribute attr) {
- return attr.isa<GPUBlockMappingAttr>();
+ return isa<GPUBlockMappingAttr>(attr);
});
bool hasThreadMapping =
llvm::any_of(forallOp.getMapping().value(), [](Attribute attr) {
- return attr.isa<GPUThreadMappingAttr>();
+ return isa<GPUThreadMappingAttr>(attr);
});
bool hasWarpMapping =
llvm::any_of(forallOp.getMapping().value(), [](Attribute attr) {
- return attr.isa<GPUWarpMappingAttr>();
+ return isa<GPUWarpMappingAttr>(attr);
});
bool hasLinearMapping =
llvm::any_of(forallOp.getMapping().value(), [](Attribute attr) {
- return attr.isa<GPULinearIdMappingAttr>();
+ return isa<GPULinearIdMappingAttr>(attr);
});
int64_t countMappingTypes = 0;
countMappingTypes += hasBlockMapping ? 1 : 0;
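The mapping checks above combine `llvm::any_of` with `isa<>`. The same predicate shape over a generic attribute list, using `UnitAttr` purely as a stand-in for the GPU mapping attributes:
```
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/STLExtras.h"

using namespace mlir;

bool containsUnitAttr(ArrayRef<Attribute> attrs) {
  return llvm::any_of(attrs, [](Attribute attr) {
    return isa<UnitAttr>(attr); // predicate form: no handle needed
  });
}
```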
@@ -520,7 +520,7 @@ static DiagnosedSilenceableFailure rewriteOneForallCommonImpl(
ArrayRef<Attribute>{forallMappingAttrs}.take_front(
forallOp.getInductionVars().size()))) {
Value peIdOp = mappingIdOps[static_cast<int64_t>(
- dim.cast<DeviceMappingAttrInterface>().getMappingId())];
+ cast<DeviceMappingAttrInterface>(dim).getMappingId())];
bvm.map(iv, peIdOp);
}
diff --git a/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp b/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
index 0a584a7920e0..ca9f2ac254c5 100644
--- a/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
@@ -214,7 +214,7 @@ private:
/// Returns an accumulator factory that creates an op specified by opName.
AccumulatorFactory getFactory(gpu::AllReduceOperation opName) {
- bool isFloatingPoint = valueType.isa<FloatType>();
+ bool isFloatingPoint = isa<FloatType>(valueType);
switch (opName) {
case gpu::AllReduceOperation::ADD:
return isFloatingPoint ? getFactory<arith::AddFOp>()
diff --git a/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp b/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp
index 0890bf267762..1fbe66ff98d7 100644
--- a/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp
@@ -158,9 +158,9 @@ async::ExecuteOp addExecuteResults(async::ExecuteOp executeOp,
transform(executeOp.getResultTypes(), std::back_inserter(resultTypes),
[](Type type) {
// Extract value type from !async.value.
- if (auto valueType = type.dyn_cast<async::ValueType>())
+ if (auto valueType = dyn_cast<async::ValueType>(type))
return valueType.getValueType();
- assert(type.isa<async::TokenType>() && "expected token type");
+ assert(isa<async::TokenType>(type) && "expected token type");
return type;
});
transform(results, std::back_inserter(resultTypes),
@@ -305,9 +305,9 @@ struct GpuAsyncRegionPass::SingleTokenUseCallback {
executeOp.getBodyResults(), [](OpResult result) {
if (result.use_empty() || result.hasOneUse())
return false;
- auto valueType = result.getType().dyn_cast<async::ValueType>();
+ auto valueType = dyn_cast<async::ValueType>(result.getType());
return valueType &&
- valueType.getValueType().isa<gpu::AsyncTokenType>();
+ isa<gpu::AsyncTokenType>(valueType.getValueType());
});
if (multiUseResults.empty())
return;
diff --git a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
index 91c1c763f070..b1e2f914db4c 100644
--- a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
@@ -338,7 +338,7 @@ public:
if (!resultAttr)
return failure();
- dataLayoutSpec = resultAttr.dyn_cast<DataLayoutSpecInterface>();
+ dataLayoutSpec = dyn_cast<DataLayoutSpecInterface>(resultAttr);
if (!dataLayoutSpec)
return failure();
}
@@ -410,7 +410,7 @@ private:
SymbolTable::getSymbolUses(symbolDefWorklist.pop_back_val())) {
for (SymbolTable::SymbolUse symbolUse : *symbolUses) {
StringRef symbolName =
- symbolUse.getSymbolRef().cast<FlatSymbolRefAttr>().getValue();
+ cast<FlatSymbolRefAttr>(symbolUse.getSymbolRef()).getValue();
if (symbolTable.lookup(symbolName))
continue;
diff --git a/mlir/lib/Dialect/GPU/Transforms/MemoryPromotion.cpp b/mlir/lib/Dialect/GPU/Transforms/MemoryPromotion.cpp
index ea9c3969c413..21de15e25088 100644
--- a/mlir/lib/Dialect/GPU/Transforms/MemoryPromotion.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/MemoryPromotion.cpp
@@ -30,7 +30,7 @@ using namespace mlir::gpu;
/// single-iteration loops. Maps the innermost loops to thread dimensions, in
/// reverse order to enable access coalescing in the innermost loop.
static void insertCopyLoops(ImplicitLocOpBuilder &b, Value from, Value to) {
- auto memRefType = from.getType().cast<MemRefType>();
+ auto memRefType = cast<MemRefType>(from.getType());
auto rank = memRefType.getRank();
SmallVector<Value, 4> lbs, ubs, steps;
@@ -121,8 +121,8 @@ static void insertCopyLoops(ImplicitLocOpBuilder &b, Value from, Value to) {
/// pointed to by "from". In case a smaller block would be sufficient, the
/// caller can create a subview of the memref and promote it instead.
static void insertCopies(Region &region, Location loc, Value from, Value to) {
- auto fromType = from.getType().cast<MemRefType>();
- auto toType = to.getType().cast<MemRefType>();
+ auto fromType = cast<MemRefType>(from.getType());
+ auto toType = cast<MemRefType>(to.getType());
(void)fromType;
(void)toType;
assert(fromType.getShape() == toType.getShape());
@@ -143,7 +143,7 @@ static void insertCopies(Region &region, Location loc, Value from, Value to) {
/// copies will be inserted at the beginning and at the end of the function.
void mlir::promoteToWorkgroupMemory(GPUFuncOp op, unsigned arg) {
Value value = op.getArgument(arg);
- auto type = value.getType().dyn_cast<MemRefType>();
+ auto type = dyn_cast<MemRefType>(value.getType());
assert(type && type.hasStaticShape() && "can only promote memrefs");
// Get the type of the buffer in the workgroup memory.
diff --git a/mlir/lib/Dialect/IRDL/IRDLVerifiers.cpp b/mlir/lib/Dialect/IRDL/IRDLVerifiers.cpp
index 71d27764f437..8b09f441a4bd 100644
--- a/mlir/lib/Dialect/IRDL/IRDLVerifiers.cpp
+++ b/mlir/lib/Dialect/IRDL/IRDLVerifiers.cpp
@@ -67,7 +67,7 @@ LogicalResult DynParametricAttrConstraint::verify(
ConstraintVerifier &context) const {
// Check that the base is the expected one.
- auto dynAttr = attr.dyn_cast<DynamicAttr>();
+ auto dynAttr = dyn_cast<DynamicAttr>(attr);
if (!dynAttr || dynAttr.getAttrDef() != attrDef) {
if (emitError) {
StringRef dialectName = attrDef->getDialect()->getNamespace();
@@ -102,7 +102,7 @@ LogicalResult DynParametricTypeConstraint::verify(
function_ref<InFlightDiagnostic()> emitError, Attribute attr,
ConstraintVerifier &context) const {
// Check that the base is a TypeAttr.
- auto typeAttr = attr.dyn_cast<TypeAttr>();
+ auto typeAttr = dyn_cast<TypeAttr>(attr);
if (!typeAttr) {
if (emitError)
return emitError() << "expected type, got attribute '" << attr;
@@ -110,7 +110,7 @@ LogicalResult DynParametricTypeConstraint::verify(
}
// Check that the type base is the expected one.
- auto dynType = typeAttr.getValue().dyn_cast<DynamicType>();
+ auto dynType = dyn_cast<DynamicType>(typeAttr.getValue());
if (!dynType || dynType.getTypeDef() != typeDef) {
if (emitError) {
StringRef dialectName = typeDef->getDialect()->getNamespace();
diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/DIScopeForLLVMFuncOp.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/DIScopeForLLVMFuncOp.cpp
index 9f38b0caddb7..ecdadd3062d3 100644
--- a/mlir/lib/Dialect/LLVMIR/Transforms/DIScopeForLLVMFuncOp.cpp
+++ b/mlir/lib/Dialect/LLVMIR/Transforms/DIScopeForLLVMFuncOp.cpp
@@ -25,11 +25,11 @@ using namespace mlir;
/// Attempt to extract a filename for the given loc.
static FileLineColLoc extractFileLoc(Location loc) {
- if (auto fileLoc = loc.dyn_cast<FileLineColLoc>())
+ if (auto fileLoc = dyn_cast<FileLineColLoc>(loc))
return fileLoc;
- if (auto nameLoc = loc.dyn_cast<NameLoc>())
+ if (auto nameLoc = dyn_cast<NameLoc>(loc))
return extractFileLoc(nameLoc.getChildLoc());
- if (auto opaqueLoc = loc.dyn_cast<OpaqueLoc>())
+ if (auto opaqueLoc = dyn_cast<OpaqueLoc>(loc))
return extractFileLoc(opaqueLoc.getFallbackLocation());
return FileLineColLoc();
}
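`Location` participates in the same free-function casting as `Type` and `Attribute`. A sketch probing one location kind that `extractFileLoc` does not handle (`CallSiteLoc` is a real builtin location class; the helper itself is illustrative):
```
#include "mlir/IR/Location.h"
#include "mlir/Support/LLVM.h"

using namespace mlir;

// Returns the callee-side location if `loc` is a call-site location.
Location calleeOrSelf(Location loc) {
  if (auto callLoc = dyn_cast<CallSiteLoc>(loc))
    return callLoc.getCallee();
  return loc;
}
```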
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
index 1936a53201f2..02909bb69977 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
@@ -607,7 +607,7 @@ DiagnosedSilenceableFailure transform::MatchStructuredResultOp::matchOperation(
return diag;
Value result = linalgOp.getTiedOpResult(linalgOp.getDpsInitOperand(position));
- if (getResult().getType().isa<TransformValueHandleTypeInterface>()) {
+ if (isa<TransformValueHandleTypeInterface>(getResult().getType())) {
results.setValues(cast<OpResult>(getResult()), result);
return DiagnosedSilenceableFailure::success();
}
@@ -648,7 +648,7 @@ transform::MatchStructuredResultOp::getPositionFor(linalg::LinalgOp op,
LogicalResult transform::MatchStructuredResultOp::verify() {
if ((getAny() || getSingle()) ^
- getResult().getType().isa<TransformHandleTypeInterface>()) {
+ isa<TransformHandleTypeInterface>(getResult().getType())) {
return emitOpError() << "expects either the any/single keyword or the type "
"value handle result type";
}
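Interface types are queried the same way as concrete types. A sketch using `ShapedType` as a stand-in for the transform type interfaces checked above:
```
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Support/LLVM.h"

using namespace mlir;

// ShapedType is a type interface; dyn_cast<> works on it just as it does on
// concrete types such as MemRefType.
bool isStaticallyShaped(Type type) {
  if (auto shaped = dyn_cast<ShapedType>(type))
    return shaped.hasStaticShape();
  return false;
}
```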
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 74eb3a2df0f9..ea8d285cf52b 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -87,7 +87,7 @@ static DiagnosedSilenceableFailure unpackSingleIndexResultPDLOperations(
SmallVector<OpFoldResult> &result, ArrayRef<OpFoldResult> ofrs) {
for (OpFoldResult ofr : ofrs) {
if (ofr.is<Attribute>()) {
- if (!ofr.get<Attribute>().isa<IntegerAttr>())
+ if (!isa<IntegerAttr>(ofr.get<Attribute>()))
return transformOp.emitDefiniteFailure() << "expected IntegerAttr";
result.push_back(ofr);
continue;
@@ -155,7 +155,7 @@ transform::BufferizeToAllocationOp::apply(transform::TransformResults &results,
llvm::map_range(state.getPayloadValues(getTarget()), [&](Value v) {
return linalg::bufferizeToAllocation(rewriter, v, memorySpace);
}));
- results.setValues(getTransformed().cast<OpResult>(), transformed);
+ results.setValues(cast<OpResult>(getTransformed()), transformed);
return DiagnosedSilenceableFailure::success();
}
@@ -276,7 +276,7 @@ static ParseResult parseTileLikeOp(OpAsmParser &parser, OperationState &result,
if (!sizesAttr)
return parser.emitError(opLoc)
<< "expected '" << sizesAttrName << "' attribute";
- auto sizesArrayAttr = sizesAttr.dyn_cast<ArrayAttr>();
+ auto sizesArrayAttr = dyn_cast<ArrayAttr>(sizesAttr);
if (!sizesArrayAttr)
return parser.emitError(opLoc)
<< "'" << sizesAttrName << "' attribute must be an array";
@@ -389,7 +389,7 @@ tileAndFuseFirstExtractUse(RewriterBase &rewriter, Diagnostic &diag,
// Tile the producer.
int64_t resultNumber =
- sliceOpToTile.getSource().cast<OpResult>().getResultNumber();
+ cast<OpResult>(sliceOpToTile.getSource()).getResultNumber();
LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n");
FailureOr<TilingResult> tileAndFuseResult =
@@ -411,10 +411,7 @@ tileAndFuseFirstExtractUse(RewriterBase &rewriter, Diagnostic &diag,
// Replace the extract op.
auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(
rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0],
- sliceOpToTile->getResult(0)
- .getType()
- .cast<RankedTensorType>()
- .getShape());
+ cast<RankedTensorType>(sliceOpToTile->getResult(0).getType()).getShape());
assert(succeeded(maybeRankReduced) && "unexpected shape");
rewriter.replaceOp(sliceOpToTile, *maybeRankReduced);
return tileAndFuseResult->tiledOps;
@@ -482,7 +479,7 @@ tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
// Replace the use in the tileableProducer before tiling: clone, replace and
// then tile.
- int64_t resultNumber = pUse->get().cast<OpResult>().getResultNumber();
+ int64_t resultNumber = cast<OpResult>(pUse->get()).getResultNumber();
LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n");
// Gather destination tensors.
@@ -516,10 +513,7 @@ tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
// Replace the extract op.
auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(
rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0],
- sliceOpToTile->getResult(0)
- .getType()
- .cast<RankedTensorType>()
- .getShape());
+ cast<RankedTensorType>(sliceOpToTile->getResult(0).getType()).getShape());
assert(succeeded(maybeRankReduced) && "unexpected shape");
rewriter.replaceOp(sliceOpToTile, *maybeRankReduced);
@@ -568,7 +562,7 @@ static Operation *cloneAndFuseFirstUse(RewriterBase &rewriter, Diagnostic &diag,
// TODO: Generalize to other type of ops.
assert(!isa<tensor::ParallelInsertSliceOp>(use->getOwner()) &&
"Parallel insert slice is not a valid clone destination");
- unsigned resultNumber = use->get().cast<OpResult>().getResultNumber();
+ unsigned resultNumber = cast<OpResult>(use->get()).getResultNumber();
LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n");
OpBuilder::InsertionGuard guard(rewriter);
@@ -587,8 +581,7 @@ transform::FuseIntoContainingOp::apply(transform::TransformResults &results,
ArrayRef<Operation *> producerOps = state.getPayloadOps(getProducerOp());
// If nothing to fuse, propagate success.
if (producerOps.empty()) {
- results.set(getFusedOp().cast<OpResult>(),
- SmallVector<mlir::Operation *>{});
+ results.set(cast<OpResult>(getFusedOp()), SmallVector<mlir::Operation *>{});
return DiagnosedSilenceableFailure::success();
}
ArrayRef<Operation *> containingOps = state.getPayloadOps(getContainingOp());
@@ -671,7 +664,7 @@ transform::FuseIntoContainingOp::apply(transform::TransformResults &results,
return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
}
- results.set(getFusedOp().cast<OpResult>(), fusedOps);
+ results.set(cast<OpResult>(getFusedOp()), fusedOps);
return DiagnosedSilenceableFailure::success();
}
@@ -865,7 +858,7 @@ transform::MatchOp::apply(transform::TransformResults &results,
};
payloadOps.front()->walk(matchFun);
- results.set(getResult().cast<OpResult>(), res);
+ results.set(cast<OpResult>(getResult()), res);
return DiagnosedSilenceableFailure::success();
}
@@ -901,7 +894,7 @@ static ParseResult parseMultitileSizesTypes(OpAsmParser &parser,
DiagnosedSilenceableFailure transform::MultiTileSizesOp::applyToOne(
LinalgOp target, transform::ApplyToEachResultList &results,
TransformState &state) {
- if (getLowSize().getType().isa<TransformParamTypeInterface>()) {
+ if (isa<TransformParamTypeInterface>(getLowSize().getType())) {
if (target.hasDynamicShape()) {
auto diag = emitSilenceableError()
<< "cannot compute parametric tile sizes for dynamically "
@@ -923,7 +916,7 @@ DiagnosedSilenceableFailure transform::MultiTileSizesOp::applyToOne(
spec->lowTileSize * spec->lowTripCount}),
[&builder, this](int64_t value) {
return builder.getIntegerAttr(
- getLowSize().getType().cast<ParamType>().getType(), value);
+ cast<ParamType>(getLowSize().getType()).getType(), value);
}));
return DiagnosedSilenceableFailure::success();
}
@@ -958,7 +951,7 @@ void transform::MultiTileSizesOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
onlyReadsHandle(getTarget(), effects);
producesHandle(getResults(), effects);
- if (getLowSize().getType().isa<TransformParamTypeInterface>())
+ if (isa<TransformParamTypeInterface>(getLowSize().getType()))
onlyReadsPayload(effects);
else
modifiesPayload(effects);
@@ -1006,7 +999,7 @@ transform::PackOp::apply(transform::TransformResults &transformResults,
ArrayRef<Operation *> targetOps = state.getPayloadOps(getTarget());
// If nothing to pack, propagate success.
if (targetOps.empty()) {
- transformResults.set(getPackedOp().cast<OpResult>(), {});
+ transformResults.set(cast<OpResult>(getPackedOp()), {});
return DiagnosedSilenceableFailure::success();
}
// Fail on multi-op handles.
@@ -1036,7 +1029,7 @@ transform::PackOp::apply(transform::TransformResults &transformResults,
if (failed(maybeResult))
return emitDefiniteFailure("data tiling failed");
- transformResults.set(getPackedOp().cast<OpResult>(),
+ transformResults.set(cast<OpResult>(getPackedOp()),
maybeResult->packedLinalgOp.getOperation());
return DiagnosedSilenceableFailure::success();
}
@@ -1242,7 +1235,7 @@ PackGreedilyOp::apply(transform::TransformResults &transformResults,
}
results.push_back(linalgOp);
}
- transformResults.set(getPackedOp().cast<OpResult>(), results);
+ transformResults.set(cast<OpResult>(getPackedOp()), results);
return DiagnosedSilenceableFailure::success();
}
@@ -1322,9 +1315,9 @@ transform::PackTransposeOp::apply(transform::TransformResults &transformResults,
ArrayRef<Operation *> linalgOps = state.getPayloadOps(getTargetLinalgOp());
// Step 1. If nothing to pack, propagate success.
if (packOrUnpackOps.empty()) {
- transformResults.set(getPackedOp().cast<OpResult>(), {});
- transformResults.set(getPackOp().cast<OpResult>(), {});
- transformResults.set(getUnPackOp().cast<OpResult>(), {});
+ transformResults.set(cast<OpResult>(getPackedOp()), {});
+ transformResults.set(cast<OpResult>(getPackOp()), {});
+ transformResults.set(cast<OpResult>(getUnPackOp()), {});
return DiagnosedSilenceableFailure::success();
}
@@ -1366,7 +1359,7 @@ transform::PackTransposeOp::apply(transform::TransformResults &transformResults,
if (unPackOp) {
assert(!packOp && "packOp must be null on entry when unPackOp is not null");
OpOperand *packUse = linalgOp.getDpsInitOperand(
- unPackOp.getSource().cast<OpResult>().getResultNumber());
+ cast<OpResult>(unPackOp.getSource()).getResultNumber());
packOp = dyn_cast_or_null<tensor::PackOp>(packUse->get().getDefiningOp());
if (!packOp || !packOp.getResult().hasOneUse())
return emitSilenceableError() << "could not find matching pack op";
@@ -1400,14 +1393,14 @@ transform::PackTransposeOp::apply(transform::TransformResults &transformResults,
assert(succeeded(res) && "unexpected packTranspose failure");
// Step 4. Return results.
- transformResults.set(getPackOp().cast<OpResult>(), {res->transposedPackOp});
- transformResults.set(getPackedOp().cast<OpResult>(),
+ transformResults.set(cast<OpResult>(getPackOp()), {res->transposedPackOp});
+ transformResults.set(cast<OpResult>(getPackedOp()),
{res->transposedLinalgOp});
if (unPackOp) {
- transformResults.set(getUnPackOp().cast<OpResult>(),
+ transformResults.set(cast<OpResult>(getUnPackOp()),
{res->transposedUnPackOp});
} else {
- transformResults.set(getUnPackOp().cast<OpResult>(), {});
+ transformResults.set(cast<OpResult>(getUnPackOp()), {});
}
return DiagnosedSilenceableFailure::success();
@@ -1430,14 +1423,14 @@ transform::PadOp::applyToOne(LinalgOp target,
SmallVector<Attribute> paddingValues;
for (auto const &it :
llvm::zip(getPaddingValues(), target->getOperandTypes())) {
- auto attr = std::get<0>(it).dyn_cast<TypedAttr>();
+ auto attr = dyn_cast<TypedAttr>(std::get<0>(it));
if (!attr) {
emitOpError("expects padding values to be typed attributes");
return DiagnosedSilenceableFailure::definiteFailure();
}
Type elementType = getElementTypeOrSelf(std::get<1>(it));
// Try to parse string attributes to obtain an attribute of element type.
- if (auto stringAttr = attr.dyn_cast<StringAttr>()) {
+ if (auto stringAttr = dyn_cast<StringAttr>(attr)) {
auto parsedAttr = dyn_cast_if_present<TypedAttr>(
parseAttribute(stringAttr, getContext(), elementType,
/*numRead=*/nullptr, /*isKnownNullTerminated=*/true));
@@ -1462,9 +1455,9 @@ transform::PadOp::applyToOne(LinalgOp target,
// Extract the transpose vectors.
SmallVector<SmallVector<int64_t>> transposePaddings;
- for (Attribute transposeVector : getTransposePaddings().cast<ArrayAttr>())
+ for (Attribute transposeVector : cast<ArrayAttr>(getTransposePaddings()))
transposePaddings.push_back(
- extractFromI64ArrayAttr(transposeVector.cast<ArrayAttr>()));
+ extractFromI64ArrayAttr(cast<ArrayAttr>(transposeVector)));
TrackingListener listener(state, *this);
IRRewriter rewriter(getContext(), &listener);
@@ -1549,13 +1542,13 @@ DiagnosedSilenceableFailure transform::HoistPadBuildPackingLoopNestOp::apply(
return emitDefiniteFailure() << "could not build packing loop nest";
if (result->clonedLoopIvs.empty()) {
- transformResults.set(getPackingLoop().cast<OpResult>(),
+ transformResults.set(cast<OpResult>(getPackingLoop()),
result->hoistedPadOp.getOperation());
return DiagnosedSilenceableFailure::success();
}
auto outerPackedLoop =
scf::getForInductionVarOwner(result->clonedLoopIvs.front());
- transformResults.set(getPackingLoop().cast<OpResult>(),
+ transformResults.set(cast<OpResult>(getPackingLoop()),
outerPackedLoop.getOperation());
return DiagnosedSilenceableFailure::success();
}
@@ -1643,7 +1636,7 @@ transform::PromoteOp::applyToOne(LinalgOp target,
if (mapping.size() > 1)
return emitDefaultDefiniteFailure(target);
- auto addressSpace = mapping[0].cast<gpu::GPUMemorySpaceMappingAttr>();
+ auto addressSpace = cast<gpu::GPUMemorySpaceMappingAttr>(mapping[0]);
if (addressSpace.getAddressSpace() ==
gpu::GPUDialect::getWorkgroupAddressSpace()) {
@@ -1711,7 +1704,7 @@ transform::ReplaceOp::apply(TransformResults &transformResults,
rewriter.replaceOp(target, replacement->getResults());
replacements.push_back(replacement);
}
- transformResults.set(getReplacement().cast<OpResult>(), replacements);
+ transformResults.set(cast<OpResult>(getReplacement()), replacements);
return DiagnosedSilenceableFailure::success();
}
@@ -1828,7 +1821,7 @@ DiagnosedSilenceableFailure SplitOp::apply(TransformResults &results,
splitPoints.reserve(payload.size());
if (getDynamicSplitPoint()) {
auto diag = DiagnosedSilenceableFailure::success();
- if (getDynamicSplitPoint().getType().isa<TransformHandleTypeInterface>()) {
+ if (isa<TransformHandleTypeInterface>(getDynamicSplitPoint().getType())) {
splitPoints = llvm::to_vector(llvm::map_range(
state.getPayloadOps(getDynamicSplitPoint()), [&](Operation *op) {
if (op->getNumResults() != 1 ||
@@ -1909,8 +1902,8 @@ DiagnosedSilenceableFailure SplitOp::apply(TransformResults &results,
return diag;
}
- results.set(getFirst().cast<OpResult>(), first);
- results.set(getSecond().cast<OpResult>(), second);
+ results.set(cast<OpResult>(getFirst()), first);
+ results.set(cast<OpResult>(getSecond()), second);
return DiagnosedSilenceableFailure::success();
}
@@ -2212,12 +2205,12 @@ transform::TileOp::apply(TransformResults &transformResults,
dynamicSizeProducers.reserve(getDynamicSizes().size());
paramSizes.reserve(getDynamicSizes().size());
for (Value transformValue : getDynamicSizes()) {
- if (transformValue.getType().isa<ParamType>()) {
+ if (isa<ParamType>(transformValue.getType())) {
dynamicSizeProducers.push_back({});
ArrayRef<Attribute> params = state.getParams(transformValue);
paramSizes.push_back(
llvm::to_vector(llvm::map_range(params, [](Attribute attr) {
- return attr.cast<IntegerAttr>().getValue().getSExtValue();
+ return cast<IntegerAttr>(attr).getValue().getSExtValue();
})));
if (paramSizes.back().size() != targets.size()) {
@@ -2247,7 +2240,7 @@ transform::TileOp::apply(TransformResults &transformResults,
for (Operation *op : dynamicSizeProducers.back()) {
if (op->getNumResults() == 1 &&
- op->getResult(0).getType().isa<IndexType>())
+ isa<IndexType>(op->getResult(0).getType()))
continue;
DiagnosedSilenceableFailure diag =
@@ -2283,7 +2276,7 @@ transform::TileOp::apply(TransformResults &transformResults,
for (OpFoldResult ofr : getMixedSizes()) {
if (auto attr = ofr.dyn_cast<Attribute>()) {
sizes.push_back(b.create<arith::ConstantIndexOp>(
- getLoc(), attr.cast<IntegerAttr>().getInt()));
+ getLoc(), cast<IntegerAttr>(attr).getInt()));
continue;
}
ArrayRef<Operation *> dynamicSizes = dynamicSizeProducers[dynamicIdx];
@@ -2320,9 +2313,9 @@ transform::TileOp::apply(TransformResults &transformResults,
loops[en2.index()].push_back(en2.value());
}
- transformResults.set(getTiledLinalgOp().cast<OpResult>(), tiled);
+ transformResults.set(cast<OpResult>(getTiledLinalgOp()), tiled);
for (const auto &en : llvm::enumerate(loops))
- transformResults.set(getLoops()[en.index()].cast<OpResult>(), en.value());
+ transformResults.set(cast<OpResult>(getLoops()[en.index()]), en.value());
return DiagnosedSilenceableFailure::success();
}
@@ -2582,8 +2575,8 @@ transform::TileToForallOp::apply(transform::TransformResults &transformResults,
tiledOps.push_back(tilingResult.tiledOp);
}
- transformResults.set(getForallOp().cast<OpResult>(), tileOps);
- transformResults.set(getTiledOp().cast<OpResult>(), tiledOps);
+ transformResults.set(cast<OpResult>(getForallOp()), tileOps);
+ transformResults.set(cast<OpResult>(getTiledOp()), tiledOps);
return DiagnosedSilenceableFailure::success();
}
@@ -2678,7 +2671,7 @@ transform::TileToScfForOp::apply(TransformResults &transformResults,
for (Operation *op : dynamicSizeProducers.back()) {
if (op->getNumResults() == 1 &&
- op->getResult(0).getType().isa<IndexType>())
+ isa<IndexType>(op->getResult(0).getType()))
continue;
DiagnosedSilenceableFailure diag =
emitSilenceableError() << "expected sizes to be produced by ops "
@@ -2712,7 +2705,7 @@ transform::TileToScfForOp::apply(TransformResults &transformResults,
for (OpFoldResult ofr : getMixedSizes()) {
if (auto attr = ofr.dyn_cast<Attribute>()) {
sizes.push_back(b.create<arith::ConstantIndexOp>(
- getLoc(), attr.cast<IntegerAttr>().getInt()));
+ getLoc(), cast<IntegerAttr>(attr).getInt()));
} else {
sizes.push_back(
dynamicSizeProducers[dynamicIdx++][index]->getResult(0));
@@ -2737,9 +2730,9 @@ transform::TileToScfForOp::apply(TransformResults &transformResults,
loops[en2.index()].push_back(en2.value());
}
- transformResults.set(getTiledLinalgOp().cast<OpResult>(), tiled);
+ transformResults.set(cast<OpResult>(getTiledLinalgOp()), tiled);
for (const auto &en : llvm::enumerate(loops))
- transformResults.set(getLoops()[en.index()].cast<OpResult>(), en.value());
+ transformResults.set(cast<OpResult>(getLoops()[en.index()]), en.value());
return DiagnosedSilenceableFailure::success();
}
@@ -2899,7 +2892,7 @@ DiagnosedSilenceableFailure transform::MaskedVectorizeOp::apply(
for (OpFoldResult sz : getMixedVectorSizes()) {
if (sz.is<Attribute>()) {
auto attr = sz.get<Attribute>();
- vectorSizes.push_back(attr.cast<IntegerAttr>().getInt());
+ vectorSizes.push_back(cast<IntegerAttr>(attr).getInt());
continue;
}
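Note what the patch leaves alone here: `sz.is<Attribute>()` and `ofr.dyn_cast<Attribute>()` stay member calls because `OpFoldResult` is a `PointerUnion`, not one of the migrated handle classes. A sketch separating the two layers (the helper name is illustrative):
```
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/OpDefinition.h" // OpFoldResult
#include "mlir/Support/LLVM.h"
#include <optional>

using namespace mlir;

// Extract a constant integer from an OpFoldResult, if it holds one.
std::optional<int64_t> getConstantInt(OpFoldResult ofr) {
  if (auto attr = ofr.dyn_cast<Attribute>())        // PointerUnion member API
    if (auto intAttr = dyn_cast<IntegerAttr>(attr)) // migrated free function
      return intAttr.getInt();
  return std::nullopt;
}
```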
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp
index 1a7d7a113b22..6b06c32d22eb 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp
@@ -64,20 +64,20 @@ public:
if (genericOp.getNumDpsInits() != 1)
return failure();
- auto outputType = genericOp.getResultTypes().front().dyn_cast<ShapedType>();
+ auto outputType = dyn_cast<ShapedType>(genericOp.getResultTypes().front());
// Require the output types to be static given that we are generating
// constants.
if (!outputType || !outputType.hasStaticShape())
return failure();
if (!llvm::all_of(genericOp.getInputs(), [](Value input) {
- return input.getType().isa<ShapedType>();
+ return isa<ShapedType>(input.getType());
}))
return failure();
// Make sure all element types are the same.
auto getOperandElementType = [](Value value) {
- return value.getType().cast<ShapedType>().getElementType();
+ return cast<ShapedType>(value.getType()).getElementType();
};
if (!llvm::all_equal(
llvm::map_range(genericOp->getOperands(), getOperandElementType)))
@@ -138,7 +138,7 @@ public:
// unify the following cases but they have lifetime as the MLIRContext.
SmallVector<APInt> intOutputValues;
SmallVector<APFloat> fpOutputValues;
- if (elementType.template isa<FloatType>())
+ if (isa<FloatType>(elementType))
fpOutputValues.resize(numElements, APFloat(0.f));
else
intOutputValues.resize(numElements);
@@ -174,7 +174,7 @@ public:
auto inputShapes = llvm::to_vector<4>(
llvm::map_range(genericOp.getInputs(), [](Value value) {
- return value.getType().cast<ShapedType>().getShape();
+ return cast<ShapedType>(value.getType()).getShape();
}));
// Given a `linearIndex`, remap it to a linear index to access linalg op
@@ -205,7 +205,7 @@ public:
}
};
- bool isFloat = elementType.isa<FloatType>();
+ bool isFloat = isa<FloatType>(elementType);
if (isFloat) {
SmallVector<DenseElementsAttr::iterator_range<APFloat>> inFpRanges;
for (int i = 0; i < numInputs; ++i)
@@ -282,7 +282,7 @@ struct FoldConstantTranspose : public FoldConstantBase<FoldConstantTranspose> {
// The yield op should return the block argument corresponding to the input.
for (Value yieldVal : yieldOp.getValues()) {
- auto yieldArg = yieldVal.dyn_cast<BlockArgument>();
+ auto yieldArg = dyn_cast<BlockArgument>(yieldVal);
if (!yieldArg || yieldArg.getOwner() != &body)
return nullptr;
if (yieldArg.getArgNumber() != 0)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
index 5423cf8d750f..48c24598f628 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp
@@ -29,7 +29,7 @@ static bool hasAllOneValues(DenseIntElementsAttr attr) {
}
static Value createAdd(Location loc, Value x, Value y, OpBuilder &builder) {
- bool isInt = x.getType().isa<IntegerType>();
+ bool isInt = isa<IntegerType>(x.getType());
if (isInt)
return builder.create<arith::AddIOp>(loc, x, y);
return builder.create<arith::AddFOp>(loc, x, y);
@@ -42,7 +42,7 @@ static Value createMul(Location loc, Value x, Value y, Type accType,
convertScalarToDtype(builder, loc, x, accType, /*isUnsignedCast=*/false);
Value yConvert =
convertScalarToDtype(builder, loc, y, accType, /*isUnsignedCast=*/false);
- if (accType.isa<IntegerType>())
+ if (isa<IntegerType>(accType))
return builder.create<arith::MulIOp>(loc, xConvert, yConvert);
return builder.create<arith::MulFOp>(loc, xConvert, yConvert);
}
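`createAdd` and `createMul` above dispatch on the operand type with `isa<>`. A condensed sketch of that dispatch, assuming a builder already positioned at the insertion point:
```
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Builders.h"
#include "mlir/Support/LLVM.h"
#include <cassert>

using namespace mlir;

// Emit x + y, choosing addi vs. addf from the operand type, as createAdd does.
Value emitAdd(OpBuilder &b, Location loc, Value x, Value y) {
  if (isa<IntegerType>(x.getType()))
    return b.create<arith::AddIOp>(loc, x, y);
  assert(isa<FloatType>(x.getType()) && "expected int or float operand");
  return b.create<arith::AddFOp>(loc, x, y);
}
```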
@@ -74,9 +74,9 @@ static Value getConvolvedIndex(OpBuilder &b, Location loc, Value oIndex,
FailureOr<std::pair<Operation *, Operation *>>
rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp) {
- auto inputType = convOp.getInputs()[0].getType().cast<ShapedType>();
- auto filterType = convOp.getInputs()[1].getType().cast<ShapedType>();
- auto outputType = convOp.getOutputs()[0].getType().cast<ShapedType>();
+ auto inputType = cast<ShapedType>(convOp.getInputs()[0].getType());
+ auto filterType = cast<ShapedType>(convOp.getInputs()[1].getType());
+ auto outputType = cast<ShapedType>(convOp.getOutputs()[0].getType());
if (!filterType.hasStaticShape())
return rewriter.notifyMatchFailure(
@@ -210,9 +210,9 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp) {
FailureOr<std::pair<Operation *, Operation *>>
rewriteInIm2Col(RewriterBase &rewriter,
linalg::DepthwiseConv2DNhwcHwcOp convOp) {
- auto inputType = convOp.getInputs()[0].getType().cast<RankedTensorType>();
- auto filterType = convOp.getInputs()[1].getType().cast<RankedTensorType>();
- auto outputType = convOp.getOutputs()[0].getType().cast<RankedTensorType>();
+ auto inputType = cast<RankedTensorType>(convOp.getInputs()[0].getType());
+ auto filterType = cast<RankedTensorType>(convOp.getInputs()[1].getType());
+ auto outputType = cast<RankedTensorType>(convOp.getOutputs()[0].getType());
if (!filterType.hasStaticShape())
return rewriter.notifyMatchFailure(
@@ -230,7 +230,7 @@ rewriteInIm2Col(RewriterBase &rewriter,
Location loc = convOp.getLoc();
auto transposeOperand = [&](Value operand, ArrayRef<int64_t> indices) {
- auto operandTensorType = operand.getType().cast<RankedTensorType>();
+ auto operandTensorType = cast<RankedTensorType>(operand.getType());
auto nloops = indices.size();
ArrayRef<int64_t> inputShape = operandTensorType.getShape();
@@ -272,7 +272,7 @@ rewriteInIm2Col(RewriterBase &rewriter,
Value inputT = transposeOperand(input, {0, 3, 1, 2});
Value filterT = transposeOperand(filter, {2, 0, 1});
ArrayRef<int64_t> filterTShape =
- filterT.getType().cast<RankedTensorType>().getShape();
+ cast<RankedTensorType>(filterT.getType()).getShape();
ArrayRef<int64_t> outputShape = outputType.getShape();
int n = outputShape[0];
@@ -360,9 +360,9 @@ rewriteInIm2Col(RewriterBase &rewriter,
FailureOr<std::pair<Operation *, Operation *>>
rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp) {
- auto inputType = convOp.getInputs()[0].getType().cast<ShapedType>();
- auto filterType = convOp.getInputs()[1].getType().cast<ShapedType>();
- auto outputType = convOp.getOutputs()[0].getType().cast<ShapedType>();
+ auto inputType = cast<ShapedType>(convOp.getInputs()[0].getType());
+ auto filterType = cast<ShapedType>(convOp.getInputs()[1].getType());
+ auto outputType = cast<ShapedType>(convOp.getOutputs()[0].getType());
if (!filterType.hasStaticShape())
return rewriter.notifyMatchFailure(
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
index 3ec5094ed90b..a81a48df00b6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
@@ -66,12 +66,12 @@ static Operation *movePaddingToFillOrGenericOp(RewriterBase &rewriter,
Attribute constYieldedValue;
// Is the yielded value a bbArg defined outside of the PadOp?
bool outsideBbArg =
- yieldedValue.isa<BlockArgument>() &&
- yieldedValue.cast<BlockArgument>().getOwner()->getParentOp() !=
+ isa<BlockArgument>(yieldedValue) &&
+ cast<BlockArgument>(yieldedValue).getOwner()->getParentOp() !=
padOp.getOperation();
// Is the yielded value an OpResult defined outside of the PadOp?
bool outsideOpResult =
- yieldedValue.isa<OpResult>() &&
+ isa<OpResult>(yieldedValue) &&
yieldedValue.getDefiningOp()->getParentOp() != padOp.getOperation();
bool invariantYieldedValue = outsideBbArg || outsideOpResult;
if (matchPattern(yieldedValue, m_Constant(&constYieldedValue))) {
@@ -120,19 +120,19 @@ static Operation *movePaddingToFillOrGenericOp(RewriterBase &rewriter,
static SmallVector<Value> reifyOrComputeDynamicSizes(OpBuilder &b,
Value value) {
- auto tensorType = value.getType().cast<RankedTensorType>();
+ auto tensorType = cast<RankedTensorType>(value.getType());
if (tensorType.hasStaticShape())
return {};
// Try to reify dynamic sizes.
ReifiedRankedShapedTypeDims reifiedShape;
- if (value.isa<OpResult>() &&
+ if (isa<OpResult>(value) &&
succeeded(reifyResultShapes(b, value.getDefiningOp(), reifiedShape))) {
SmallVector<Value> dynSizes;
for (int64_t i = 0; i < tensorType.getRank(); ++i) {
if (tensorType.isDynamicDim(i))
dynSizes.push_back(
- reifiedShape[value.cast<OpResult>().getResultNumber()][i]
+ reifiedShape[cast<OpResult>(value).getResultNumber()][i]
.get<Value>());
}
return dynSizes;
@@ -153,12 +153,12 @@ static Value createAllocationForTensor(RewriterBase &rewriter, Location loc,
Value value,
Attribute memorySpace = {}) {
OpBuilder::InsertionGuard g(rewriter);
- auto tensorType = value.getType().cast<RankedTensorType>();
+ auto tensorType = cast<RankedTensorType>(value.getType());
// Create buffer allocation.
- auto memrefType = bufferization::getMemRefTypeWithStaticIdentityLayout(
- tensorType, memorySpace)
- .cast<MemRefType>();
+ auto memrefType =
+ cast<MemRefType>(bufferization::getMemRefTypeWithStaticIdentityLayout(
+ tensorType, memorySpace));
SmallVector<Value> dynamicSizes = reifyOrComputeDynamicSizes(rewriter, value);
Value alloc = rewriter.create<memref::AllocOp>(loc, memrefType, dynamicSizes);
@@ -206,7 +206,7 @@ FailureOr<Operation *> mlir::linalg::rewriteInDestinationPassingStyle(
RewriterBase &rewriter, tensor::FromElementsOp fromElementsOp) {
Location loc = fromElementsOp.getLoc();
RankedTensorType tensorType =
- fromElementsOp.getType().cast<RankedTensorType>();
+ cast<RankedTensorType>(fromElementsOp.getType());
auto shape = tensorType.getShape();
// Create tensor.empty.
@@ -247,7 +247,7 @@ mlir::linalg::rewriteInDestinationPassingStyle(RewriterBase &rewriter,
return failure();
Location loc = generateOp.getLoc();
- RankedTensorType tensorType = generateOp.getType().cast<RankedTensorType>();
+ RankedTensorType tensorType = cast<RankedTensorType>(generateOp.getType());
// Create tensor.empty.
auto emptyOp =
@@ -339,7 +339,7 @@ Value linalg::bufferizeToAllocation(RewriterBase &rewriter, Value value,
llvm::map_range(value.getUses(), [](OpOperand &use) { return &use; }));
OpBuilder::InsertionGuard g(rewriter);
- if (auto bbArg = value.dyn_cast<BlockArgument>()) {
+ if (auto bbArg = dyn_cast<BlockArgument>(value)) {
rewriter.setInsertionPointToStart(bbArg.getOwner());
} else {
rewriter.setInsertionPointAfter(value.getDefiningOp());
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index e5764cb30e03..1ddd8b144c60 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -640,7 +640,7 @@ pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp) {
auto loc = genericOp.getLoc();
Value unPackDest = producerUnPackOp.getDest();
auto genericOutType =
- genericOp.getDpsInitOperand(0)->get().getType().cast<RankedTensorType>();
+ cast<RankedTensorType>(genericOp.getDpsInitOperand(0)->get().getType());
if (producerUnPackOp.getDestType() != genericOutType ||
!genericOutType.hasStaticShape()) {
unPackDest = tensor::UnPackOp::createDestinationTensor(
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp b/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp
index e381b0aa011c..42f87a16c92f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp
@@ -132,12 +132,12 @@ SmallVector<OpFoldResult> permuteValues(ArrayRef<OpFoldResult> values,
static Value getZero(OpBuilder &b, Location loc, Type elementType) {
assert(elementType.isIntOrIndexOrFloat() &&
"expected scalar type while computing zero value");
- if (elementType.isa<IntegerType>())
+ if (isa<IntegerType>(elementType))
return b.create<arith::ConstantIntOp>(loc, 0, elementType);
if (elementType.isIndex())
return b.create<arith::ConstantIndexOp>(loc, 0);
// Assume float.
- auto floatType = elementType.cast<FloatType>();
+ auto floatType = cast<FloatType>(elementType);
return b.create<arith::ConstantFloatOp>(
loc, APFloat::getZero(floatType.getFloatSemantics()), floatType);
}
@@ -179,7 +179,7 @@ DecomposeLinalgOp::createPeeledGenericOp(GenericOp genericOp,
if (resultNumber) {
newInitValues.push_back(
genericOp.getDpsInitOperand(*resultNumber)->get());
- OpResult result = genericOp.getResult(*resultNumber).cast<OpResult>();
+ OpResult result = cast<OpResult>(genericOp.getResult(*resultNumber));
newResultTypes.push_back(result.getType());
peeledGenericOpIndexingMaps.push_back(
genericOp.getIndexingMapMatchingResult(result));
@@ -231,7 +231,7 @@ DecomposeLinalgOp::createResidualGenericOp(GenericOp genericOp,
}));
for (auto resultNum :
llvm::seq<unsigned>(origNumResults, peeledGenericOpNumResults)) {
- OpResult result = peeledGenericOp.getResult(resultNum).cast<OpResult>();
+ OpResult result = cast<OpResult>(peeledGenericOp.getResult(resultNum));
indexingMaps.push_back(
peeledGenericOp.getIndexingMapMatchingResult(result));
}
@@ -348,7 +348,7 @@ DecomposeLinalgOp::matchAndRewrite(GenericOp genericOp,
/// the peeled operation.
SmallVector<Value> replacements;
for (const auto &yieldValue : llvm::enumerate(yieldOp->getOperands())) {
- OpResult opr = yieldValue.value().dyn_cast<OpResult>();
+ OpResult opr = dyn_cast<OpResult>(yieldValue.value());
if (!opr || opr.getOwner() != peeledScalarOperation)
replacements.push_back(residualGenericOp.getResult(yieldValue.index()));
else
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
index 5fd48853875c..bf91a708ae15 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
@@ -32,7 +32,7 @@ static Value sourceMaterializationCallback(OpBuilder &builder, Type type,
ValueRange inputs, Location loc) {
assert(inputs.size() == 1);
auto inputType = inputs[0].getType();
- if (inputType.isa<TensorType>())
+ if (isa<TensorType>(inputType))
return nullptr;
// A detensored value is converted back by creating a new tensor from its
@@ -320,9 +320,9 @@ struct LinalgDetensorize
// * Add the argument to blockArgsToDetensor.
// * Walk the use-def chain backwards to add each predecessor's
// terminator-operands corresponding to currentItem to workList.
- if (currentItem.dyn_cast<BlockArgument>()) {
+ if (dyn_cast<BlockArgument>(currentItem)) {
BlockArgument currentItemBlockArgument =
- currentItem.cast<BlockArgument>();
+ cast<BlockArgument>(currentItem);
Block *ownerBlock = currentItemBlockArgument.getOwner();
// Function arguments are not detensored/converted.
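Aside (not part of the commit): the mechanical clang-tidy rewrite above keeps the original dyn_cast-then-cast pair on `currentItem`; a hand-written version would fold the test and the conversion into one query. A minimal sketch, with an illustrative helper name:

```
// Minimal sketch, not from this commit: one dyn_cast both tests and
// converts, so no separate cast<BlockArgument> is needed afterwards.
#include "mlir/IR/Block.h"
#include "mlir/IR/Value.h"

using namespace mlir;

static Block *getOwnerBlockIfBlockArgument(Value v) {
  if (auto bbArg = dyn_cast<BlockArgument>(v))
    return bbArg.getOwner();
  return nullptr;
}
```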
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index 4a2c0a64fc07..d8eccb967589 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -308,7 +308,7 @@ struct MoveInitOperandsToInput : public OpRewritePattern<GenericOp> {
for (OpOperand *op : candidates) {
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointAfterValue(op->get());
- auto elemType = op->get().getType().cast<ShapedType>().getElementType();
+ auto elemType = cast<ShapedType>(op->get().getType()).getElementType();
auto empty = rewriter.create<tensor::EmptyOp>(
loc, tensor::createDimValues(rewriter, loc, op->get()), elemType);
@@ -387,7 +387,7 @@ replaceUnitExtents(GenericOp genericOp, OpOperand *opOperand,
// Early return for memrefs with affine maps to represent that we will always
// leave them unchanged.
Type actualType = opOperand->get().getType();
- if (auto memref = actualType.dyn_cast<MemRefType>()) {
+ if (auto memref = dyn_cast<MemRefType>(actualType)) {
if (!memref.getLayout().isIdentity())
return std::nullopt;
}
@@ -437,7 +437,7 @@ struct ReplaceUnitExtents : public OpRewritePattern<GenericOp> {
ArrayRef<ReassociationIndices> reassociation, Location loc,
PatternRewriter &rewriter) const {
// There are no results for memref outputs.
- auto origResultType = origOutput.getType().cast<RankedTensorType>();
+ auto origResultType = cast<RankedTensorType>(origOutput.getType());
if (rankReductionStrategy == RankReductionStrategy::ExtractInsertSlice) {
unsigned rank = origResultType.getRank();
SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
@@ -459,7 +459,7 @@ struct ReplaceUnitExtents : public OpRewritePattern<GenericOp> {
Value collapseValue(Value operand, ArrayRef<int64_t> targetShape,
ArrayRef<ReassociationIndices> reassociation,
Location loc, PatternRewriter &rewriter) const {
- if (auto memrefType = operand.getType().dyn_cast<MemRefType>()) {
+ if (auto memrefType = dyn_cast<MemRefType>(operand.getType())) {
if (rankReductionStrategy == RankReductionStrategy::ExtractInsertSlice) {
FailureOr<Value> rankReducingExtract =
memref::SubViewOp::rankReduceIfNeeded(rewriter, loc, operand,
@@ -478,7 +478,7 @@ struct ReplaceUnitExtents : public OpRewritePattern<GenericOp> {
return rewriter.create<memref::CollapseShapeOp>(loc, targetType, operand,
reassociation);
}
- if (auto tensorType = operand.getType().dyn_cast<RankedTensorType>()) {
+ if (auto tensorType = dyn_cast<RankedTensorType>(operand.getType())) {
if (rankReductionStrategy == RankReductionStrategy::ExtractInsertSlice) {
FailureOr<Value> rankReducingExtract =
tensor::ExtractSliceOp::rankReduceIfNeeded(rewriter, loc, operand,
@@ -502,7 +502,7 @@ struct ReplaceUnitExtents : public OpRewritePattern<GenericOp> {
PatternRewriter &rewriter) const override {
// Skip the pattern if the op has any tensor with special encoding.
if (llvm::any_of(genericOp->getOperandTypes(), [](Type type) {
- auto tensorType = type.dyn_cast<RankedTensorType>();
+ auto tensorType = dyn_cast<RankedTensorType>(type);
return tensorType && tensorType.getEncoding() != nullptr;
}))
return failure();
@@ -607,11 +607,10 @@ struct RankReducedExtractSliceOp
if (!reassociation ||
reassociation->size() == static_cast<size_t>(resultType.getRank()))
return failure();
- auto rankReducedType =
+ auto rankReducedType = cast<RankedTensorType>(
tensor::ExtractSliceOp::inferCanonicalRankReducedResultType(
reassociation->size(), sliceOp.getSourceType(), offsets, sizes,
- strides)
- .cast<RankedTensorType>();
+ strides));
Location loc = sliceOp.getLoc();
Value newSlice = rewriter.create<tensor::ExtractSliceOp>(
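Aside (not part of the commit): the `collapseValue` hunks above dispatch on the operand's type with the free-function spelling. A minimal sketch of that dispatch idiom, assuming the usual MLIR headers; the helper name and its policy are illustrative only:

```
// Minimal sketch, not from this commit: dyn_cast<T>(type) for the branch
// that needs the concrete type, isa<T>(type) for a pure test.
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"

using namespace mlir;

static bool isCollapsibleShapedValue(Value operand) {
  // Memrefs are only collapsible here with an identity layout.
  if (auto memrefType = dyn_cast<MemRefType>(operand.getType()))
    return memrefType.getLayout().isIdentity();
  // Any ranked tensor qualifies.
  return isa<RankedTensorType>(operand.getType());
}
```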
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index bf728a6ec319..33ff4a3ecc09 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -87,7 +87,7 @@ bool mlir::linalg::areElementwiseOpsFusable(OpOperand *fusedOperand) {
// type. Producer must have full tensor semantics to avoid potential
// aliasing between producer and consumer memrefs.
if (!producer.hasTensorSemantics() ||
- !fusedOperand->get().getType().isa<RankedTensorType>())
+ !isa<RankedTensorType>(fusedOperand->get().getType()))
return false;
// Verify that
@@ -232,14 +232,14 @@ static void generateFusedElementwiseOpRegion(
// forward the yield operand.
auto producerYieldOp = cast<linalg::YieldOp>(producerBlock.getTerminator());
unsigned producerResultNumber =
- fusedOperand->get().cast<OpResult>().getResultNumber();
+ cast<OpResult>(fusedOperand->get()).getResultNumber();
Value replacement =
mapper.lookupOrDefault(producerYieldOp.getOperand(producerResultNumber));
// Sanity checks: if the replacement is not already in the mapper, then it
// must be produced outside.
if (replacement == producerYieldOp.getOperand(producerResultNumber)) {
- if (auto bb = replacement.dyn_cast<BlockArgument>())
+ if (auto bb = dyn_cast<BlockArgument>(replacement))
assert(bb.getOwner() != &producerBlock &&
"yielded block argument must have been mapped");
else
@@ -278,7 +278,7 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
OpOperand *fusedOperand) {
assert(areElementwiseOpsFusable(fusedOperand) &&
"expected elementwise operation pre-conditions to pass");
- auto producerResult = fusedOperand->get().cast<OpResult>();
+ auto producerResult = cast<OpResult>(fusedOperand->get());
auto producer = cast<GenericOp>(producerResult.getOwner());
auto consumer = cast<GenericOp>(fusedOperand->getOwner());
// TODO: allow fusing the producer of an output operand.
@@ -357,7 +357,7 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
fusedOutputOperands.push_back(opOperand->get());
fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(opOperand));
Type resultType = opOperand->get().getType();
- if (!resultType.isa<MemRefType>())
+ if (!isa<MemRefType>(resultType))
fusedResultTypes.push_back(resultType);
}
@@ -512,7 +512,7 @@ static bool isFusableWithReshapeByDimExpansion(GenericOp genericOp,
return genericOp.hasTensorSemantics() &&
llvm::all_of(genericOp.getIndexingMaps().getValue(),
[](Attribute attr) {
- return attr.cast<AffineMapAttr>()
+ return cast<AffineMapAttr>(attr)
.getValue()
.isProjectedPermutation();
}) &&
@@ -776,7 +776,7 @@ fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
continue;
}
if (auto opOperandType =
- opOperand->get().getType().dyn_cast<RankedTensorType>()) {
+ dyn_cast<RankedTensorType>(opOperand->get().getType())) {
AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand);
RankedTensorType expandedOperandType =
getExpandedType(opOperandType, indexingMap, expansionInfo);
@@ -805,7 +805,7 @@ fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
SmallVector<Value> outputs;
for (OpOperand *opOperand : genericOp.getDpsInitOperands()) {
AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand);
- auto opOperandType = opOperand->get().getType().cast<RankedTensorType>();
+ auto opOperandType = cast<RankedTensorType>(opOperand->get().getType());
RankedTensorType expandedOutputType =
getExpandedType(opOperandType, indexingMap, expansionInfo);
if (expandedOutputType != opOperand->get().getType()) {
@@ -921,7 +921,7 @@ struct FoldReshapeWithGenericOpByExpansion
LogicalResult matchAndRewrite(tensor::ExpandShapeOp reshapeOp,
PatternRewriter &rewriter) const override {
// Fold only if all constraints of fusing with reshape by expansion are met.
- auto producerResult = reshapeOp.getSrc().dyn_cast<OpResult>();
+ auto producerResult = dyn_cast<OpResult>(reshapeOp.getSrc());
if (!producerResult) {
return rewriter.notifyMatchFailure(reshapeOp,
"source not produced by an operation");
@@ -959,8 +959,9 @@ struct FoldReshapeWithGenericOpByExpansion
// same type as the returns of the original generic op, the consumer reshape
// op can be replaced by the source of the collapse_shape op that defines
// the replacement.
- Value reshapeReplacement = (*replacementValues)
- [reshapeOp.getSrc().cast<OpResult>().getResultNumber()];
+ Value reshapeReplacement =
+ (*replacementValues)[cast<OpResult>(reshapeOp.getSrc())
+ .getResultNumber()];
if (auto collapseOp =
reshapeReplacement.getDefiningOp<tensor::CollapseShapeOp>()) {
reshapeReplacement = collapseOp.getSrc();
@@ -1447,7 +1448,7 @@ FailureOr<SmallVector<Value>> mlir::linalg::collapseGenericOpIterationDims(
.createLoopRanges(rewriter, genericOp.getLoc());
auto opFoldIsConstantValue = [](OpFoldResult ofr, int64_t value) {
if (auto attr = ofr.dyn_cast<Attribute>())
- return attr.cast<IntegerAttr>().getInt() == value;
+ return cast<IntegerAttr>(attr).getInt() == value;
llvm::APInt actual;
return matchPattern(ofr.get<Value>(), m_ConstantInt(&actual)) &&
actual.getSExtValue() == value;
@@ -1521,8 +1522,8 @@ FailureOr<SmallVector<Value>> mlir::linalg::collapseGenericOpIterationDims(
Value collapsedOpResult =
collapsedGenericOp->getResult(originalResult.index());
auto originalResultType =
- originalResult.value().getType().cast<ShapedType>();
- auto collapsedOpResultType = collapsedOpResult.getType().cast<ShapedType>();
+ cast<ShapedType>(originalResult.value().getType());
+ auto collapsedOpResultType = cast<ShapedType>(collapsedOpResult.getType());
if (collapsedOpResultType.getRank() != originalResultType.getRank()) {
AffineMap indexingMap =
genericOp.getIndexingMapMatchingResult(originalResult.value());
@@ -1671,7 +1672,7 @@ public:
return false;
};
- auto resultValue = opOperand->get().dyn_cast<OpResult>();
+ auto resultValue = dyn_cast<OpResult>(opOperand->get());
if (!def || !resultValue || !isScalarOrSplatConstantOp(def))
continue;
@@ -1756,7 +1757,7 @@ struct RemoveOutsDependency : public OpRewritePattern<GenericOp> {
for (OpOperand *opOperand : op.getDpsInitOperands()) {
if (!op.payloadUsesValueFromOperand(opOperand)) {
Value operandVal = opOperand->get();
- auto operandType = operandVal.getType().dyn_cast<RankedTensorType>();
+ auto operandType = dyn_cast<RankedTensorType>(operandVal.getType());
if (!operandType)
continue;
@@ -1810,7 +1811,7 @@ struct FoldFillWithGenericOp : public OpRewritePattern<GenericOp> {
fillFound = true;
Value fillVal = fillOp.value();
auto resultType =
- fillOp.result().getType().cast<RankedTensorType>().getElementType();
+ cast<RankedTensorType>(fillOp.result().getType()).getElementType();
Value convertedVal =
convertScalarToDtype(rewriter, fillOp.getLoc(), fillVal, resultType,
/*isUnsignedCast =*/false);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
index 549764de593b..18026cc15033 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
@@ -28,7 +28,7 @@ static bool isElementwiseMappableOpOnRankedTensors(Operation *op) {
// TODO: The conversion pattern can be made to work for `any_of` here, but
// it's more complex as it requires tracking which operands are scalars.
return llvm::all_of(op->getOperandTypes(),
- [](Type type) { return type.isa<RankedTensorType>(); });
+ [](Type type) { return isa<RankedTensorType>(type); });
}
/// Given `op` assumed `isElementwiseMappableOpOnRankedTensors`, iterate over
@@ -67,7 +67,7 @@ getOrCreateOperandsMatchingResultTypes(OpBuilder &b, Operation *op) {
// Extract static / dynamic shape mix from the first operand.
Value firstOperand = operands.front();
- auto rankedTensorType = t.cast<RankedTensorType>();
+ auto rankedTensorType = cast<RankedTensorType>(t);
auto staticShape = llvm::to_vector<4>(rankedTensorType.getShape());
auto dynamicShape = linalg::createDynamicDimensions(b, loc, firstOperand);
@@ -87,7 +87,7 @@ struct ConvertAnyElementwiseMappableOpOnRankedTensors : public RewritePattern {
return rewriter.notifyMatchFailure(
op, "requires elementwise op on ranked tensors");
- auto rank = op->getResult(0).getType().cast<RankedTensorType>().getRank();
+ auto rank = cast<RankedTensorType>(op->getResult(0).getType()).getRank();
SmallVector<AffineMap, 3> indexingMaps(
op->getNumResults() + op->getNumOperands(),
rewriter.getMultiDimIdentityMap(rank));
@@ -104,7 +104,7 @@ struct ConvertAnyElementwiseMappableOpOnRankedTensors : public RewritePattern {
[&](OpBuilder &builder, Location loc, ValueRange regionArgs) {
auto resultTypes = llvm::to_vector<6>(
llvm::map_range(op->getResultTypes(), [](Type type) {
- return type.cast<TensorType>().getElementType();
+ return cast<TensorType>(type).getElementType();
}));
auto *scalarOp =
builder.create(loc, op->getName().getIdentifier(),
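Aside (not part of the commit): the same free functions cover `Operation*`/Op handles as well as types. Since `getDefiningOp()` is null for block arguments, a null guard (or the `_or_null` variant mentioned in the commit message) is needed before the cast. A minimal sketch with an illustrative helper:

```
// Minimal sketch, not from this commit: guarding a possibly-null defining
// op before an isa<> test on it.
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Value.h"

using namespace mlir;

static bool isProducedByConstant(Value v) {
  Operation *def = v.getDefiningOp(); // null when `v` is a block argument
  return def && isa<arith::ConstantOp>(def);
}
```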
diff --git a/mlir/lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp b/mlir/lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp
index defa02751758..c89fc5b9da8d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp
@@ -89,7 +89,7 @@ struct DeduplicateAndRemoveDeadOperandsAndResults
Location loc = genericOp.getLoc();
SmallVector<Type> newResultTypes;
for (Value v : newOutputOperands)
- if (v.getType().isa<TensorType>())
+ if (isa<TensorType>(v.getType()))
newResultTypes.push_back(v.getType());
auto newOp = rewriter.create<GenericOp>(
loc, newResultTypes, newInputOperands, newOutputOperands,
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusePadOpWithLinalgProducer.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusePadOpWithLinalgProducer.cpp
index b6e2ffcbba36..703db8373c31 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/FusePadOpWithLinalgProducer.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusePadOpWithLinalgProducer.cpp
@@ -86,12 +86,12 @@ struct FusePadOp : OpRewritePattern<tensor::PadOp> {
// result of the generic op. The low pad values are the offsets; the size of
// the source is the size of the slice.
// TODO: This insert/extract could be potentially made a utility method.
- unsigned resultNumber = source.cast<OpResult>().getResultNumber();
+ unsigned resultNumber = cast<OpResult>(source).getResultNumber();
SmallVector<OpFoldResult> offsets = padOp.getMixedLowPad();
SmallVector<OpFoldResult> sizes;
sizes.reserve(offsets.size());
- for (const auto &shape : llvm::enumerate(
- source.getType().cast<RankedTensorType>().getShape())) {
+ for (const auto &shape :
+ llvm::enumerate(cast<RankedTensorType>(source.getType()).getShape())) {
if (ShapedType::isDynamic(shape.value())) {
sizes.push_back(
rewriter.create<tensor::DimOp>(loc, source, shape.index())
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index 6f9b60843d6d..cf3fd4ba0a0b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -151,7 +151,7 @@ static LinalgOp fuse(OpBuilder &b, LinalgOp producer,
SmallVector<Type, 4> resultTypes;
resultTypes.reserve(producer->getNumResults());
for (OpOperand *operand : producer.getDpsInitOperands()) {
- auto tensorType = operand->get().getType().dyn_cast<RankedTensorType>();
+ auto tensorType = dyn_cast<RankedTensorType>(operand->get().getType());
if (!tensorType)
continue;
unsigned rank = tensorType.getRank();
@@ -210,20 +210,20 @@ static LinalgOp fuse(OpBuilder &b, LinalgOp producerOp, AffineMap producerMap,
// dependence tracking since the dependence tracking is similar to what is done
// w.r.t. buffers.
static void getProducerOfTensor(Value tensor, OpResult &opResult) {
- if (!tensor.getType().isa<RankedTensorType>())
+ if (!isa<RankedTensorType>(tensor.getType()))
return;
while (true) {
LLVM_DEBUG(llvm::dbgs() << "\ngetProducerOfTensor: " << tensor);
if (auto linalgOp = tensor.getDefiningOp<LinalgOp>()) {
- opResult = tensor.cast<OpResult>();
+ opResult = cast<OpResult>(tensor);
return;
}
if (auto sliceOp = tensor.getDefiningOp<tensor::ExtractSliceOp>()) {
tensor = sliceOp.getSource();
continue;
}
- if (auto blockArg = tensor.dyn_cast<BlockArgument>()) {
+ if (auto blockArg = dyn_cast<BlockArgument>(tensor)) {
if (auto forOp = blockArg.getDefiningOp<scf::ForOp>()) {
tensor = *(forOp.getIterOperands().begin() + blockArg.getArgNumber());
continue;
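Aside (not part of the commit): `getProducerOfTensor` above relies on every `Value` being either an `OpResult` or a `BlockArgument`, so a single `dyn_cast` distinguishes the two. A minimal sketch of that distinction, with an illustrative helper name:

```
// Minimal sketch, not from this commit: recover the producing operation of
// a value, or null for a block argument.
#include "mlir/IR/Value.h"

using namespace mlir;

static Operation *getProducerOrNull(Value tensor) {
  if (auto result = dyn_cast<OpResult>(tensor))
    return result.getOwner();
  return nullptr; // block argument: no single producing operation
}
```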
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
index d8ecc807ea05..87aade3a3eec 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -227,7 +227,7 @@ SmallVector<BlockArgument> TileLoopNest::getTiedBBArgs(BlockArgument bbArg) {
return {};
bbArgs.push_back(bbArg);
OpOperand *iterArg = &tileLoop.getOpOperandForRegionIterArg(bbArg);
- bbArg = iterArg->get().dyn_cast<BlockArgument>();
+ bbArg = dyn_cast<BlockArgument>(iterArg->get());
}
// Reverse the block arguments to order them from outer to inner.
@@ -358,13 +358,13 @@ FailureOr<LinalgOp> TileLoopNest::fuseProducer(OpBuilder &b,
// Check if the producer is a LinalgOp possibly passed by iteration argument.
OpOperand *iterArg = nullptr;
- auto producerResult = sliceOp.getSource().dyn_cast<OpResult>();
- if (auto bbArg = sliceOp.getSource().dyn_cast<BlockArgument>()) {
+ auto producerResult = dyn_cast<OpResult>(sliceOp.getSource());
+ if (auto bbArg = dyn_cast<BlockArgument>(sliceOp.getSource())) {
iterArg = getTiedIterArg(bbArg);
// Check that the iteration argument may be used to pass in the producer output.
if (!iterArg || hasOtherUses(bbArg, sliceOp))
return failure();
- producerResult = iterArg->get().dyn_cast<OpResult>();
+ producerResult = dyn_cast<OpResult>(iterArg->get());
}
if (!producerResult || !isa<LinalgOp>(producerResult.getOwner()))
return failure();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
index 251f7d8575f7..21d83d225d70 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
@@ -549,7 +549,7 @@ static FailureOr<PackingResult> buildPackingLoopNestImpl(
int paddedRank = paddedTensorType.getRank();
// Step 0. Populate bvm with opToHoist.getSource if relevant.
- BlockArgument bbArg = opToHoist.getSource().dyn_cast<BlockArgument>();
+ BlockArgument bbArg = dyn_cast<BlockArgument>(opToHoist.getSource());
while (bbArg) {
auto forOp = dyn_cast<scf::ForOp>(bbArg.getOwner()->getParentOp());
if (!forOp)
@@ -558,7 +558,7 @@ static FailureOr<PackingResult> buildPackingLoopNestImpl(
break;
OpOperand &operand = forOp.getOpOperandForRegionIterArg(bbArg);
bvm.map(bbArg, operand.get());
- bbArg = operand.get().dyn_cast<BlockArgument>();
+ bbArg = dyn_cast<BlockArgument>(operand.get());
}
// Step 1. iteratively clone loops and push `hoistedPackedTensor`.
@@ -754,9 +754,8 @@ static bool tracesBackToExpectedValue(tensor::ExtractSliceOp extractSliceOp,
if (!destOp)
break;
LLVM_DEBUG(DBGS() << "--step dest op: " << destOp << "\n");
- source =
- destOp.getDpsInitOperand(source.cast<OpResult>().getResultNumber())
- ->get();
+ source = destOp.getDpsInitOperand(cast<OpResult>(source).getResultNumber())
+ ->get();
}
LLVM_DEBUG(DBGS() << "--final source: " << source << "\n");
LLVM_DEBUG(DBGS() << "--expected source: " << expectedSource << "\n");
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
index 13ec4d92ad26..01b893a0e0a5 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
@@ -86,7 +86,7 @@ void mlir::linalg::hoistRedundantVectorTransfers(func::FuncOp func) {
[&](LoopLikeOpInterface loopLike) { moveLoopInvariantCode(loopLike); });
func.walk([&](vector::TransferReadOp transferRead) {
- if (!transferRead.getShapedType().isa<MemRefType>())
+ if (!isa<MemRefType>(transferRead.getShapedType()))
return WalkResult::advance();
LLVM_DEBUG(DBGS() << "Candidate for hoisting: "
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
index 23c831f0a018..d91d8c4bf610 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
@@ -162,7 +162,7 @@ static void emitScalarImplementation(OpBuilder &b, Location loc,
SmallVector<SmallVector<Value>, 8> indexing;
SmallVector<Value> outputBuffers;
for (OpOperand *outputOperand : linalgOp.getDpsInitOperands()) {
- if (!outputOperand->get().getType().isa<MemRefType>())
+ if (!isa<MemRefType>(outputOperand->get().getType()))
continue;
indexing.push_back(makeCanonicalAffineApplies(
b, loc, linalgOp.getMatchingIndexingMap(outputOperand),
@@ -242,7 +242,7 @@ static FailureOr<LinalgLoops> linalgOpToLoopsImpl(RewriterBase &rewriter,
return failure();
// The induction variable is a block argument of the entry block of the
// loop operation.
- BlockArgument ivVal = iv.dyn_cast<BlockArgument>();
+ BlockArgument ivVal = dyn_cast<BlockArgument>(iv);
if (!ivVal)
return failure();
loopSet.insert(ivVal.getOwner()->getParentOp());
diff --git a/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp b/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp
index cabd342d86c0..93fa5ff24ac6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp
@@ -44,9 +44,9 @@ matchAndReplaceDepthwiseConv(Operation *operation, Value input, Value kernel,
auto result = operation->getResult(0);
- auto kernelTy = kernel.getType().dyn_cast<RankedTensorType>();
- auto initTy = init.getType().dyn_cast<RankedTensorType>();
- auto resultTy = result.getType().template dyn_cast<RankedTensorType>();
+ auto kernelTy = dyn_cast<RankedTensorType>(kernel.getType());
+ auto initTy = dyn_cast<RankedTensorType>(init.getType());
+ auto resultTy = dyn_cast<RankedTensorType>(result.getType());
if (!kernelTy || !initTy || !resultTy)
return failure();
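Aside (not part of the commit): the removed line above needed `.template dyn_cast<RankedTensorType>()` because `result`'s type is template-dependent; the free-function spelling sidesteps the disambiguator entirely. A minimal sketch:

```
// Minimal sketch, not from this commit: no `.template` keyword is needed
// with the free-function cast inside a template.
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"

using namespace mlir;

template <typename OpTy>
static RankedTensorType getRankedResultType(OpTy op) {
  // Member form would have been:
  //   op->getResult(0).getType().template dyn_cast<RankedTensorType>();
  return dyn_cast<RankedTensorType>(op->getResult(0).getType());
}
```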
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
index 4fcffea14e03..d39cd0e686e0 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
@@ -292,9 +292,9 @@ promoteSubViews(ImplicitLocOpBuilder &b,
})
.Case([&](ComplexType t) {
Value tmp;
- if (auto et = t.getElementType().dyn_cast<FloatType>())
+ if (auto et = dyn_cast<FloatType>(t.getElementType()))
tmp = b.create<arith::ConstantOp>(FloatAttr::get(et, 0.0));
- else if (auto et = t.getElementType().cast<IntegerType>())
+ else if (auto et = cast<IntegerType>(t.getElementType()))
tmp = b.create<arith::ConstantOp>(IntegerAttr::get(et, 0));
return b.create<complex::CreateOp>(t, tmp, tmp);
})
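Aside (not part of the commit): note the asymmetry the migration preserves in the hunk above: the float branch uses `dyn_cast` (a test), while the integer fallback uses `cast`, which asserts on mismatch, so the code relies on the element type being float or integer. A fully guarded check, as a sketch:

```
// Minimal sketch, not from this commit: isa<T> returns false instead of
// asserting, so both branches are safe to probe.
#include "mlir/IR/BuiltinTypes.h"

using namespace mlir;

static bool hasFloatOrIntegerElement(ComplexType t) {
  Type elt = t.getElementType();
  return isa<FloatType>(elt) || isa<IntegerType>(elt);
}
```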
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Split.cpp b/mlir/lib/Dialect/Linalg/Transforms/Split.cpp
index 344b2893c906..203ae437a2a5 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Split.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Split.cpp
@@ -93,7 +93,7 @@ linalg::splitOp(RewriterBase &rewriter, TilingInterface op, unsigned dimension,
{iterationSpace[dimension].offset, iterationSpace[dimension].size,
minSplitPoint});
if (auto attr = remainingSize.dyn_cast<Attribute>()) {
- if (attr.cast<IntegerAttr>().getValue().isZero())
+ if (cast<IntegerAttr>(attr).getValue().isZero())
return {op, TilingInterface()};
}
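Aside (not part of the commit): the hunk above deliberately leaves `remainingSize.dyn_cast<Attribute>()` as a member call. `OpFoldResult` is a `PointerUnion<Attribute, Value>`, which this migration does not cover; only the inner `Attribute` downcast moves to the free function. A minimal sketch of the mixed idiom, with an illustrative helper:

```
// Minimal sketch, not from this commit: member dyn_cast for the
// PointerUnion-based OpFoldResult, free-function cast for the Attribute.
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/OpDefinition.h"

using namespace mlir;

static bool isStaticZero(OpFoldResult ofr) {
  // PointerUnion access: unchanged by this commit.
  if (auto attr = ofr.dyn_cast<Attribute>())
    // Attribute downcast: the migrated form. Assumes static sizes are
    // IntegerAttrs, as the surrounding code does; cast asserts otherwise.
    return cast<IntegerAttr>(attr).getValue().isZero();
  return false;
}
```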
diff --git a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
index b4d95b70de83..982b0243e953 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
@@ -113,7 +113,7 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
}
Type newType = RankedTensorType::get(
newShape,
- operand->get().getType().cast<RankedTensorType>().getElementType());
+ cast<RankedTensorType>(operand->get().getType()).getElementType());
Value newInput = b.create<tensor::ExpandShapeOp>(
loc, newType, operand->get(), reassociation);
newInputs.push_back(newInput);
@@ -309,7 +309,7 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReductionByScaling(
fillOps.reserve(op.getNumDpsInits());
for (auto it : llvm::zip(op.getDpsInitOperands(), neutralElements)) {
Value rankedTensor = std::get<0>(it)->get();
- auto t = rankedTensor.getType().cast<RankedTensorType>();
+ auto t = cast<RankedTensorType>(rankedTensor.getType());
RankedTensorType newT = RankedTensorType::Builder(t).insertDim(
reductionDimSize / splitFactor, insertSplitDimension);
SmallVector<Value> dims =
@@ -383,7 +383,7 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReductionByScaling(
combinerOps)) {
Value reindexedOutput = std::get<0>(it);
Value originalOutput = std::get<1>(it)->get();
- auto originalOutputType = originalOutput.getType().cast<RankedTensorType>();
+ auto originalOutputType = cast<RankedTensorType>(originalOutput.getType());
Operation *combinerOp = std::get<2>(it);
AffineMap map = b.getMultiDimIdentityMap(originalOutputType.getRank() + 1);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/SubsetHoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/SubsetHoisting.cpp
index c0355a14d366..f4556787668d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/SubsetHoisting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/SubsetHoisting.cpp
@@ -65,7 +65,7 @@ static FailureOr<tensor::ExtractSliceOp>
findHoistableMatchingExtractSlice(RewriterBase &rewriter,
tensor::InsertSliceOp insertSliceOp,
BlockArgument srcTensor) {
- assert(srcTensor.getType().isa<RankedTensorType>() && "not a ranked tensor");
+ assert(isa<RankedTensorType>(srcTensor.getType()) && "not a ranked tensor");
auto forOp = cast<scf::ForOp>(srcTensor.getOwner()->getParentOp());
@@ -92,7 +92,7 @@ findHoistableMatchingExtractSlice(RewriterBase &rewriter,
// Skip insert_slice whose source is defined within the loop: we need to
// hoist that definition first; otherwise dominance violations trigger.
- if (!extractSliceOp.getSource().isa<BlockArgument>() &&
+ if (!isa<BlockArgument>(extractSliceOp.getSource()) &&
!forOp.isDefinedOutsideOfLoop(extractSliceOp.getSource())) {
LLVM_DEBUG(DBGS() << "------transfer_read vector is loop-dependent\n");
continue;
@@ -119,7 +119,7 @@ static FailureOr<vector::TransferReadOp>
findHoistableMatchingTransferRead(RewriterBase &rewriter,
vector::TransferWriteOp transferWriteOp,
BlockArgument srcTensor) {
- if (!srcTensor.getType().isa<RankedTensorType>())
+ if (!isa<RankedTensorType>(srcTensor.getType()))
return failure();
auto forOp = cast<scf::ForOp>(srcTensor.getOwner()->getParentOp());
@@ -152,7 +152,7 @@ findHoistableMatchingTransferRead(RewriterBase &rewriter,
// transfer_read may be of a vector that is defined within the loop: we
// traverse it by virtue of bypassing disjoint subset operations rooted at
// a bbArg and yielding a matching yield.
- if (!read.getSource().isa<BlockArgument>() &&
+ if (!isa<BlockArgument>(read.getSource()) &&
!forOp.isDefinedOutsideOfLoop(read.getSource())) {
LLVM_DEBUG(DBGS() << "------transfer_read vector appears loop "
"dependent but will be tested for disjointness as "
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index 1ff11665b402..57798fc78ea4 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -49,7 +49,7 @@ static bool isZero(OpFoldResult v) {
if (!v)
return false;
if (auto attr = v.dyn_cast<Attribute>()) {
- IntegerAttr intAttr = attr.dyn_cast<IntegerAttr>();
+ IntegerAttr intAttr = dyn_cast<IntegerAttr>(attr);
return intAttr && intAttr.getValue().isZero();
}
if (auto cst = v.get<Value>().getDefiningOp<arith::ConstantIndexOp>())
@@ -105,7 +105,7 @@ void mlir::linalg::transformIndexOps(
static void emitIsPositiveIndexAssertion(ImplicitLocOpBuilder &b,
OpFoldResult value) {
if (auto attr = value.dyn_cast<Attribute>()) {
- assert(attr.cast<IntegerAttr>().getValue().isStrictlyPositive() &&
+ assert(cast<IntegerAttr>(attr).getValue().isStrictlyPositive() &&
"expected strictly positive tile size and divisor");
return;
}
@@ -587,8 +587,8 @@ tileLinalgOpImpl(RewriterBase &b, LinalgOp op, ArrayRef<OpFoldResult> tileSizes,
SmallVector<Operation *, 8> loops;
loops.reserve(ivs.size());
for (auto iv : ivs) {
- if (iv.isa<BlockArgument>()) {
- loops.push_back(iv.cast<BlockArgument>().getOwner()->getParentOp());
+ if (isa<BlockArgument>(iv)) {
+ loops.push_back(cast<BlockArgument>(iv).getOwner()->getParentOp());
assert(loops.back() && "no owner found for induction variable!");
} else {
// TODO: Instead of doing this, try to recover the ops used instead of the
@@ -712,7 +712,7 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
outOffsets[reductionDim] = forallOp.getInductionVars().front();
// TODO: use SubsetExtractOpInterface once it is available.
tiledDpsInitOperands.push_back(b.create<tensor::ExtractSliceOp>(
- loc, initOperand->get().getType().cast<RankedTensorType>(),
+ loc, cast<RankedTensorType>(initOperand->get().getType()),
destBbArgs[destNum], outOffsets, sizes, strides));
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index 1c3745f66cbf..36f13fa64dcc 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -365,8 +365,7 @@ struct LinalgOpPartialReductionInterface
// Then create a new reduction that only reduce the newly added dimension
// from the previous op.
- int64_t intermRank =
- partialReduce[0].getType().cast<ShapedType>().getRank();
+ int64_t intermRank = cast<ShapedType>(partialReduce[0].getType()).getRank();
AffineMap inputMap = b.getMultiDimIdentityMap(intermRank);
SmallVector<utils::IteratorType> reductionIteratorTypes;
SmallVector<AffineExpr> exprs;
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index a9e8ac0bbabb..230089582f25 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -89,7 +89,7 @@ static FailureOr<Value> padOperandToSmallestStaticBoundingBox(
// Follow the use-def chain if `currOpOperand` is defined by a LinalgOp.
OpOperand *currOpOperand = opOperand;
while (auto linalgOp = currOpOperand->get().getDefiningOp<LinalgOp>()) {
- OpResult result = currOpOperand->get().cast<OpResult>();
+ OpResult result = cast<OpResult>(currOpOperand->get());
currOpOperand = linalgOp.getDpsInitOperand(result.getResultNumber());
}
@@ -133,7 +133,7 @@ static FailureOr<Value> padOperandToSmallestStaticBoundingBox(
// If the size is an attribute add it directly to `paddedShape`.
if (en.value().is<Attribute>()) {
paddedShape[shapeIdx++] =
- en.value().get<Attribute>().dyn_cast<IntegerAttr>().getInt();
+ dyn_cast<IntegerAttr>(en.value().get<Attribute>()).getInt();
LLVM_DEBUG(
DBGS() << "------dim is an attr, add it to padded shape, SKIP\n");
continue;
@@ -232,7 +232,7 @@ linalg::rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad,
for (const auto &en : llvm::enumerate(paddedOp->getResults())) {
Value paddedResult = en.value();
int64_t resultNumber = en.index();
- int64_t rank = paddedResult.getType().cast<RankedTensorType>().getRank();
+ int64_t rank = cast<RankedTensorType>(paddedResult.getType()).getRank();
SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> sizes;
SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
@@ -476,7 +476,7 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
tensor::PackOp packOp) {
// 1. Filter out NYI cases.
auto packedTensorType =
- packOp->getResultTypes().front().cast<RankedTensorType>();
+ cast<RankedTensorType>(packOp->getResultTypes().front());
if (llvm::any_of(packOp.getStaticInnerTiles(),
[](int64_t size) { return ShapedType::isDynamic(size); })) {
return rewriter.notifyMatchFailure(
@@ -639,7 +639,7 @@ FailureOr<LowerUnPackOpResult> linalg::lowerUnPack(RewriterBase &rewriter,
int64_t packedRank = packedTensorType.getRank();
OpFoldResult zero = rewriter.getIndexAttr(0), one = rewriter.getIndexAttr(1);
- auto destTensorType = unPackOp.getDest().getType().cast<RankedTensorType>();
+ auto destTensorType = cast<RankedTensorType>(unPackOp.getDest().getType());
if (unPackOp.isLikeUnPad()) {
// This unpack is just a plain unpad.
// Just extract the slice from the higher ranked tensor.
@@ -889,7 +889,7 @@ static LinalgOp transposeOneLinalgOperandAndReplace(
// Sanity check of the expected transposed tensor type.
auto tensorType = permuteShape(
- opOperand.get().getType().cast<RankedTensorType>(), permutation);
+ cast<RankedTensorType>(opOperand.get().getType()), permutation);
(void)tensorType;
assert(tensorType == transposedValue.getType() &&
"expected tensor type mismatch");
@@ -1050,8 +1050,8 @@ LogicalResult
PadOpTransformationPattern::matchAndRewrite(tensor::PadOp padOp,
PatternRewriter &rewriter) const {
- auto inputShapedType = padOp.getSource().getType().cast<ShapedType>();
- auto resultShapedType = padOp.getResult().getType().cast<ShapedType>();
+ auto inputShapedType = cast<ShapedType>(padOp.getSource().getType());
+ auto resultShapedType = cast<ShapedType>(padOp.getResult().getType());
// Bail on non-static shapes.
if (!inputShapedType.hasStaticShape())
@@ -1068,7 +1068,7 @@ PadOpTransformationPattern::matchAndRewrite(tensor::PadOp padOp,
Operation *definingOp = padValue.getDefiningOp();
if (definingOp && definingOp->getBlock() == &block)
return failure();
- if (!definingOp && padValue.cast<BlockArgument>().getOwner() == &block)
+ if (!definingOp && cast<BlockArgument>(padValue).getOwner() == &block)
return failure();
// Create tensor with the padded shape
@@ -1134,7 +1134,7 @@ GeneralizePadOpPattern::matchAndRewrite(tensor::PadOp padOp,
return val;
return rewriter
.create<arith::ConstantIndexOp>(
- padOp.getLoc(), ofr.get<Attribute>().cast<IntegerAttr>().getInt())
+ padOp.getLoc(), cast<IntegerAttr>(ofr.get<Attribute>()).getInt())
.getResult();
};
@@ -1514,9 +1514,9 @@ FailureOr<Conv1DOp> DownscaleSizeOneWindowed2DConvolution<Conv2DOp, Conv1DOp>::
Value kernel = convOp.getInputs().back();
Value output = convOp.getOutputs().front();
- auto inputType = input.getType().dyn_cast<RankedTensorType>();
- auto kernelType = kernel.getType().dyn_cast<RankedTensorType>();
- auto outputType = output.getType().dyn_cast<RankedTensorType>();
+ auto inputType = dyn_cast<RankedTensorType>(input.getType());
+ auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
+ auto outputType = dyn_cast<RankedTensorType>(output.getType());
auto kernelShape = kernelType.getShape();
auto outputShape = outputType.getShape();
@@ -1638,9 +1638,9 @@ DownscaleDepthwiseConv2DNhwcHwcOp::returningMatchAndRewrite(
Value kernel = convOp.getInputs().back();
Value output = convOp.getOutputs().front();
- auto inputType = input.getType().dyn_cast<RankedTensorType>();
- auto kernelType = kernel.getType().dyn_cast<RankedTensorType>();
- auto outputType = output.getType().dyn_cast<RankedTensorType>();
+ auto inputType = dyn_cast<RankedTensorType>(input.getType());
+ auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
+ auto outputType = dyn_cast<RankedTensorType>(output.getType());
auto kernelShape = kernelType.getShape();
auto outputShape = outputType.getShape();
@@ -1706,9 +1706,9 @@ DownscaleConv2DOp::returningMatchAndRewrite(Conv2DOp convOp,
Value kernel = convOp.getInputs().back();
Value output = convOp.getOutputs().front();
- auto inputType = input.getType().dyn_cast<RankedTensorType>();
- auto kernelType = kernel.getType().dyn_cast<RankedTensorType>();
- auto outputType = output.getType().dyn_cast<RankedTensorType>();
+ auto inputType = dyn_cast<RankedTensorType>(input.getType());
+ auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
+ auto outputType = dyn_cast<RankedTensorType>(output.getType());
auto kernelShape = kernelType.getShape();
auto outputShape = outputType.getShape();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 56b4516452a1..2236d1bff111 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -563,7 +563,7 @@ static Value buildVectorWrite(RewriterBase &rewriter, Value value,
loc, value, outputOperand->get(), indices, writeMap);
} else {
// 0-d case is still special: do not invert the reindexing writeMap.
- if (!value.getType().isa<VectorType>())
+ if (!isa<VectorType>(value.getType()))
value = rewriter.create<vector::BroadcastOp>(loc, vectorType, value);
assert(value.getType() == vectorType && "incorrect type");
write = rewriter.create<vector::TransferWriteOp>(
@@ -864,7 +864,7 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
targetShape.back() == 1)
return VectorMemoryAccessKind::Gather;
- auto inputShape = extractOp.getTensor().getType().cast<ShapedType>();
+ auto inputShape = cast<ShapedType>(extractOp.getTensor().getType());
// 2. Assume that it's a gather load when reading _from_ a tensor for which
// the trailing dimension is 1, e.g. `tensor<1x4x1xi32>`.
@@ -1024,8 +1024,8 @@ static Operation *reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op,
const IRMapping &bvm) {
Value reduceVec = bvm.lookup(reduceValue);
Value outputVec = bvm.lookup(initialValue);
- auto reduceType = reduceVec.getType().dyn_cast<VectorType>();
- auto outputType = outputVec.getType().dyn_cast<VectorType>();
+ auto reduceType = dyn_cast<VectorType>(reduceVec.getType());
+ auto outputType = dyn_cast<VectorType>(outputVec.getType());
// Reduce only if needed as the value may already have been reduced for
// contraction vectorization.
if (!reduceType ||
@@ -1082,7 +1082,7 @@ vectorizeOneOp(RewriterBase &rewriter, LinalgOp linalgOp, Operation *op,
// 4 . Check if the operation is a reduction.
SmallVector<std::pair<Value, Value>> reductionOperands;
for (Value operand : op->getOperands()) {
- auto blockArg = operand.dyn_cast<BlockArgument>();
+ auto blockArg = dyn_cast<BlockArgument>(operand);
if (!blockArg || blockArg.getOwner() != linalgOp.getBlock() ||
blockArg.getArgNumber() < linalgOp.getNumDpsInputs())
continue;
@@ -1107,7 +1107,7 @@ vectorizeOneOp(RewriterBase &rewriter, LinalgOp linalgOp, Operation *op,
// a. First, get the first max-ranked shape.
SmallVector<int64_t, 4> firstMaxRankedShape;
for (Value operand : op->getOperands()) {
- auto vt = bvm.lookup(operand).getType().dyn_cast<VectorType>();
+ auto vt = dyn_cast<VectorType>(bvm.lookup(operand).getType());
if (vt && firstMaxRankedShape.size() < vt.getShape().size())
firstMaxRankedShape.assign(vt.getShape().begin(), vt.getShape().end());
}
@@ -1230,7 +1230,7 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
// 3.c. Not all ops support 0-d vectors; extract the scalar for now.
// TODO: remove this.
- if (readValue.getType().cast<VectorType>().getRank() == 0)
+ if (cast<VectorType>(readValue.getType()).getRank() == 0)
readValue = rewriter.create<vector::ExtractElementOp>(loc, readValue);
LDBG("New vectorized bbarg(" << bbarg.getArgNumber() << "): " << readValue
@@ -1528,8 +1528,8 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, LinalgOp linalgOp,
LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
memref::CopyOp copyOp) {
- auto srcType = copyOp.getSource().getType().cast<MemRefType>();
- auto dstType = copyOp.getTarget().getType().cast<MemRefType>();
+ auto srcType = cast<MemRefType>(copyOp.getSource().getType());
+ auto dstType = cast<MemRefType>(copyOp.getTarget().getType());
if (!srcType.hasStaticShape() || !dstType.hasStaticShape())
return failure();
@@ -1549,7 +1549,7 @@ LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
Value readValue = rewriter.create<vector::TransferReadOp>(
loc, readType, copyOp.getSource(), indices,
rewriter.getMultiDimIdentityMap(srcType.getRank()));
- if (readValue.getType().cast<VectorType>().getRank() == 0) {
+ if (cast<VectorType>(readValue.getType()).getRank() == 0) {
readValue = rewriter.create<vector::ExtractElementOp>(loc, readValue);
readValue = rewriter.create<vector::BroadcastOp>(loc, writeType, readValue);
}
@@ -1566,7 +1566,7 @@ LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
/// Helper function that retrieves the value of an IntegerAttr.
static int64_t getIntFromAttr(Attribute attr) {
- return attr.cast<IntegerAttr>().getInt();
+ return cast<IntegerAttr>(attr).getInt();
}
/// Given an ArrayRef of OpFoldResults, return a vector of Values.
@@ -1836,8 +1836,8 @@ struct PadOpVectorizationWithTransferWritePattern
if (hasSameTensorSize(castOp.getSource(), afterTrimming))
return true;
- auto t1 = beforePadding.getType().dyn_cast<RankedTensorType>();
- auto t2 = afterTrimming.getType().dyn_cast<RankedTensorType>();
+ auto t1 = dyn_cast<RankedTensorType>(beforePadding.getType());
+ auto t2 = dyn_cast<RankedTensorType>(afterTrimming.getType());
// Only RankedTensorType supported.
if (!t1 || !t2)
return false;
@@ -1946,7 +1946,7 @@ struct PadOpVectorizationWithInsertSlicePattern
if (!padValue)
return failure();
// Dynamic shapes not supported.
- if (!padOp.getResult().getType().cast<ShapedType>().hasStaticShape())
+ if (!cast<ShapedType>(padOp.getResult().getType()).hasStaticShape())
return failure();
// Pad result not used as destination.
if (insertOp.getDest() == padOp.getResult())
@@ -2074,7 +2074,7 @@ LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
memref::CopyOp copyOp;
for (auto &u : subView.getUses()) {
if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
- assert(newCopyOp.getTarget().getType().isa<MemRefType>());
+ assert(isa<MemRefType>(newCopyOp.getTarget().getType()));
if (newCopyOp.getTarget() != subView)
continue;
if (mayExistInterleavedUses(newCopyOp, xferOp, {viewOrAlloc, subView}))
@@ -2091,7 +2091,7 @@ LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
FillOp maybeFillOp;
for (auto &u : viewOrAlloc.getUses()) {
if (auto newFillOp = dyn_cast<FillOp>(u.getOwner())) {
- assert(newFillOp.output().getType().isa<MemRefType>());
+ assert(isa<MemRefType>(newFillOp.output().getType()));
if (newFillOp.output() != viewOrAlloc)
continue;
if (mayExistInterleavedUses(newFillOp, copyOp, {viewOrAlloc, subView}))
@@ -2162,7 +2162,7 @@ LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
return rewriter.notifyMatchFailure(xferOp, "no copy found");
// `out` is the subview being copied into, which we replace.
- assert(copyOp.getTarget().getType().isa<MemRefType>());
+ assert(isa<MemRefType>(copyOp.getTarget().getType()));
Value out = copyOp.getTarget();
// Forward vector.transfer into copy.
@@ -2204,7 +2204,7 @@ static void bindShapeDims(ShapedType shapedType, IntTy &...vals) {
namespace {
bool isCastOfBlockArgument(Operation *op) {
return isa<CastOpInterface>(op) && op->getNumOperands() == 1 &&
- op->getOperand(0).isa<BlockArgument>();
+ isa<BlockArgument>(op->getOperand(0));
}
bool isSupportedPoolKind(vector::CombiningKind kind) {
@@ -2268,9 +2268,9 @@ struct Conv1DGenerator
lhsShaped = linalgOp.getDpsInputOperand(0)->get();
rhsShaped = linalgOp.getDpsInputOperand(1)->get();
resShaped = linalgOp.getDpsInitOperand(0)->get();
- lhsShapedType = lhsShaped.getType().dyn_cast<ShapedType>();
- rhsShapedType = rhsShaped.getType().dyn_cast<ShapedType>();
- resShapedType = resShaped.getType().dyn_cast<ShapedType>();
+ lhsShapedType = dyn_cast<ShapedType>(lhsShaped.getType());
+ rhsShapedType = dyn_cast<ShapedType>(rhsShaped.getType());
+ resShapedType = dyn_cast<ShapedType>(resShaped.getType());
if (!lhsShapedType || !rhsShapedType || !resShapedType)
return;
// (LHS has dimension NCW/NWC and RES has dimension NFW/NCW/NWF/NWC) OR
@@ -2717,8 +2717,8 @@ struct Conv1DGenerator
/// Lower lhs{n, w, c} * rhs{c} -> res{n, w, c} to MulAcc
Value depthwiseConv1dSliceAsMulAcc(RewriterBase &rewriter, Location loc,
Value lhs, Value rhs, Value res) {
- auto rhsTy = rhs.getType().cast<ShapedType>();
- auto resTy = res.getType().cast<ShapedType>();
+ auto rhsTy = cast<ShapedType>(rhs.getType());
+ auto resTy = cast<ShapedType>(res.getType());
// TODO(suderman): Change this to use a vector.fma intrinsic.
lhs = promote(rewriter, loc, lhs, resTy);
@@ -2730,7 +2730,7 @@ struct Conv1DGenerator
if (!lhs || !rhs)
return nullptr;
- if (resTy.getElementType().isa<FloatType>())
+ if (isa<FloatType>(resTy.getElementType()))
return rewriter.create<vector::FMAOp>(loc, lhs, rhs, res);
auto mul = rewriter.create<arith::MulIOp>(loc, lhs, rhs);
@@ -2863,15 +2863,14 @@ private:
// Otherwise, check for one or zero `ext` predecessors. The `ext` operands
// must be block arguments or extension of block arguments.
bool setOperKind(Operation *reduceOp) {
- int numBlockArguments =
- llvm::count_if(reduceOp->getOperands(),
- [](Value v) { return v.isa<BlockArgument>(); });
+ int numBlockArguments = llvm::count_if(
+ reduceOp->getOperands(), [](Value v) { return isa<BlockArgument>(v); });
switch (numBlockArguments) {
case 1: {
// Will be a convolution if the feeder is a MulOp.
// Otherwise, it may be pooling.
auto feedValIt = llvm::find_if(reduceOp->getOperands(), [](Value v) {
- return !v.isa<BlockArgument>();
+ return !isa<BlockArgument>(v);
});
Operation *feedOp = (*feedValIt).getDefiningOp();
if (isCastOfBlockArgument(feedOp)) {
@@ -2880,7 +2879,7 @@ private:
poolExtOp = feedOp->getName().getIdentifier();
} else if (!(isa<arith::MulIOp, arith::MulFOp>(feedOp) &&
llvm::all_of(feedOp->getOperands(), [](Value v) {
- if (v.isa<BlockArgument>())
+ if (isa<BlockArgument>(v))
return true;
if (Operation *op = v.getDefiningOp())
return isCastOfBlockArgument(op);
diff --git a/mlir/lib/Dialect/Linalg/Utils/IndexingUtils.cpp b/mlir/lib/Dialect/Linalg/Utils/IndexingUtils.cpp
index 12b55ef2e660..f7376c0d9602 100644
--- a/mlir/lib/Dialect/Linalg/Utils/IndexingUtils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/IndexingUtils.cpp
@@ -43,16 +43,16 @@
namespace mlir {
namespace linalg {
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim) {
- if (val.getType().isa<UnrankedMemRefType, MemRefType>())
+ if (isa<UnrankedMemRefType, MemRefType>(val.getType()))
return b.createOrFold<memref::DimOp>(loc, val, dim);
- if (val.getType().isa<UnrankedTensorType, RankedTensorType>())
+ if (isa<UnrankedTensorType, RankedTensorType>(val.getType()))
return b.createOrFold<tensor::DimOp>(loc, val, dim);
llvm_unreachable("Expected MemRefType or TensorType");
}
OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val,
int64_t dim) {
- auto shapedType = val.getType().cast<ShapedType>();
+ auto shapedType = cast<ShapedType>(val.getType());
if (!shapedType.hasRank() || shapedType.isDynamicDim(dim))
return createOrFoldDimOp(b, loc, val, dim);
return b.getIndexAttr(shapedType.getDimSize(dim));
@@ -60,7 +60,7 @@ OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val,
SmallVector<Value> createDynamicDimensions(OpBuilder &b, Location loc,
Value val) {
- auto shapedType = val.getType().cast<ShapedType>();
+ auto shapedType = cast<ShapedType>(val.getType());
assert(shapedType.hasRank() && "`val` must have a static rank");
SmallVector<Value> res;
res.reserve(shapedType.getRank());
@@ -73,7 +73,7 @@ SmallVector<Value> createDynamicDimensions(OpBuilder &b, Location loc,
SmallVector<OpFoldResult> getMixedDimensions(OpBuilder &b, Location loc,
Value val) {
- auto shapedType = val.getType().cast<ShapedType>();
+ auto shapedType = cast<ShapedType>(val.getType());
assert(shapedType.hasRank() && "`val` must have a static rank");
SmallVector<Value> dynamicDims = createDynamicDimensions(b, loc, val);
return getMixedValues(shapedType.getShape(), dynamicDims, b);
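Aside (not part of the commit): the free-function `isa` keeps the variadic multi-type form that `createOrFoldDimOp` above uses, testing several candidate types in one call. A minimal sketch:

```
// Minimal sketch, not from this commit: one variadic isa<> call per family
// of types, mirroring the memref/tensor split above.
#include "mlir/IR/BuiltinTypes.h"

using namespace mlir;

static bool isMemRefLike(Type t) {
  return isa<UnrankedMemRefType, MemRefType>(t);
}
static bool isTensorLike(Type t) {
  return isa<UnrankedTensorType, RankedTensorType>(t);
}
```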
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 5e3413accf7c..ef31668ed25b 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -281,7 +281,7 @@ Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type,
auto linalgOp = current.getDefiningOp<LinalgOp>();
if (!linalgOp)
break;
- OpResult opResult = current.cast<OpResult>();
+ OpResult opResult = cast<OpResult>(current);
current = linalgOp.getDpsInitOperand(opResult.getResultNumber())->get();
}
auto padOp = current ? current.getDefiningOp<tensor::PadOp>() : nullptr;
@@ -331,7 +331,7 @@ Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type,
GenericOp makeTransposeOp(OpBuilder &b, Location loc, Value inputTensor,
Value outputTensor,
ArrayRef<int64_t> transposeVector) {
- auto resultTensorType = outputTensor.getType().cast<RankedTensorType>();
+ auto resultTensorType = cast<RankedTensorType>(outputTensor.getType());
Type elementType = resultTensorType.getElementType();
assert(isPermutationVector(transposeVector) &&
@@ -366,9 +366,9 @@ GenericOp makeTransposeOp(OpBuilder &b, Location loc, Value inputTensor,
}
GenericOp makeMemRefCopyOp(OpBuilder &b, Location loc, Value from, Value to) {
- auto memrefTypeTo = to.getType().cast<MemRefType>();
+ auto memrefTypeTo = cast<MemRefType>(to.getType());
#ifndef NDEBUG
- auto memrefTypeFrom = from.getType().cast<MemRefType>();
+ auto memrefTypeFrom = cast<MemRefType>(from.getType());
assert(memrefTypeFrom.getRank() == memrefTypeTo.getRank() &&
"`from` and `to` memref must have the same rank");
#endif // NDEBUG
@@ -650,7 +650,7 @@ void GenerateLoopNest<scf::ParallelOp>::doit(
static Value materializeTiledShape(OpBuilder &builder, Location loc,
Value valueToTile,
const SliceParameters &sliceParams) {
- auto shapedType = valueToTile.getType().dyn_cast<ShapedType>();
+ auto shapedType = dyn_cast<ShapedType>(valueToTile.getType());
auto *sliceOp = TypeSwitch<ShapedType, Operation *>(shapedType)
.Case([&](MemRefType) {
return builder.create<memref::SubViewOp>(
@@ -685,7 +685,7 @@ computeSliceParameters(OpBuilder &builder, Location loc, Value valueToTile,
ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
ArrayRef<OpFoldResult> subShapeSizes,
bool omitPartialTileCheck) {
- auto shapedType = valueToTile.getType().dyn_cast<ShapedType>();
+ auto shapedType = dyn_cast<ShapedType>(valueToTile.getType());
assert(shapedType && "only shaped types can be tiled");
ArrayRef<int64_t> shape = shapedType.getShape();
int64_t rank = shapedType.getRank();
@@ -889,7 +889,7 @@ computeAllSliceParameters(OpBuilder &builder, Location loc, LinalgOp linalgOp,
// subdomains explicit.
Type operandType = opOperand.get().getType();
- if (!isTiled(map, tileSizes) && !(operandType.isa<RankedTensorType>() &&
+ if (!isTiled(map, tileSizes) && !(isa<RankedTensorType>(operandType) &&
linalgOp.isDpsInit(&opOperand))) {
allSliceParams.push_back(std::nullopt);
LLVM_DEBUG(llvm::dbgs()
@@ -971,7 +971,7 @@ getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes) {
auto size = it.value();
curr.push_back(dim);
auto attr = size.dyn_cast<Attribute>();
- if (attr && attr.cast<IntegerAttr>().getInt() == 1)
+ if (attr && cast<IntegerAttr>(attr).getInt() == 1)
continue;
reassociation.emplace_back(ReassociationIndices{});
std::swap(reassociation.back(), curr);
@@ -989,7 +989,7 @@ std::optional<TypedAttr> getNeutralElement(Operation *op) {
// Builder only used as helper for attribute creation.
OpBuilder b(op->getContext());
Type resultType = op->getResult(0).getType();
- if (auto floatType = resultType.dyn_cast<FloatType>()) {
+ if (auto floatType = dyn_cast<FloatType>(resultType)) {
const llvm::fltSemantics &semantic = floatType.getFloatSemantics();
if (isa<arith::AddFOp>(op))
return b.getFloatAttr(resultType, llvm::APFloat::getZero(semantic));
diff --git a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
index c5e008e52047..dcace489673f 100644
--- a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp
@@ -64,7 +64,7 @@ PowFStrengthReduction::matchAndRewrite(math::PowFOp op,
// Maybe broadcasts scalar value into vector type compatible with `op`.
auto bcast = [&](Value value) -> Value {
- if (auto vec = op.getType().dyn_cast<VectorType>())
+ if (auto vec = dyn_cast<VectorType>(op.getType()))
return rewriter.create<vector::BroadcastOp>(op.getLoc(), vec, value);
return value;
};
@@ -167,7 +167,7 @@ PowIStrengthReduction<PowIOpTy, DivOpTy, MulOpTy>::matchAndRewrite(
// Maybe broadcasts scalar value into vector type compatible with `op`.
auto bcast = [&loc, &op, &rewriter](Value value) -> Value {
- if (auto vec = op.getType().template dyn_cast<VectorType>())
+ if (auto vec = dyn_cast<VectorType>(op.getType()))
return rewriter.create<vector::BroadcastOp>(loc, vec, value);
return value;
};
diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
index 6d286a31290e..a3efc6ef41a9 100644
--- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
@@ -40,7 +40,7 @@ using namespace mlir::vector;
// Returns vector shape if the type is a vector. Returns an empty shape if it is
// not a vector.
static ArrayRef<int64_t> vectorShape(Type type) {
- auto vectorType = type.dyn_cast<VectorType>();
+ auto vectorType = dyn_cast<VectorType>(type);
return vectorType ? vectorType.getShape() : ArrayRef<int64_t>();
}
@@ -54,14 +54,14 @@ static ArrayRef<int64_t> vectorShape(Value value) {
// Broadcasts scalar type into vector type (iff shape is non-scalar).
static Type broadcast(Type type, ArrayRef<int64_t> shape) {
- assert(!type.isa<VectorType>() && "must be scalar type");
+ assert(!isa<VectorType>(type) && "must be scalar type");
return !shape.empty() ? VectorType::get(shape, type) : type;
}
// Broadcasts scalar value into vector (iff shape is non-scalar).
static Value broadcast(ImplicitLocOpBuilder &builder, Value value,
ArrayRef<int64_t> shape) {
- assert(!value.getType().isa<VectorType>() && "must be scalar value");
+ assert(!isa<VectorType>(value.getType()) && "must be scalar value");
auto type = broadcast(value.getType(), shape);
return !shape.empty() ? builder.create<BroadcastOp>(type, value) : value;
}
@@ -92,7 +92,7 @@ handleMultidimensionalVectors(ImplicitLocOpBuilder &builder,
assert(!operands.empty() && "operands must be not empty");
assert(vectorWidth > 0 && "vector width must be larger than 0");
- VectorType inputType = operands[0].getType().cast<VectorType>();
+ VectorType inputType = cast<VectorType>(operands[0].getType());
ArrayRef<int64_t> inputShape = inputType.getShape();
// If input shape matches target vector width, we can just call the
@@ -118,7 +118,7 @@ handleMultidimensionalVectors(ImplicitLocOpBuilder &builder,
for (unsigned i = 0; i < operands.size(); ++i) {
auto operand = operands[i];
- auto eltType = operand.getType().cast<VectorType>().getElementType();
+ auto eltType = cast<VectorType>(operand.getType()).getElementType();
auto expandedType = VectorType::get(expandedShape, eltType);
expandedOperands[i] =
builder.create<vector::ShapeCastOp>(expandedType, operand);
@@ -145,7 +145,7 @@ handleMultidimensionalVectors(ImplicitLocOpBuilder &builder,
}
// Stitch results together into one large vector.
- Type resultEltType = results[0].getType().cast<VectorType>().getElementType();
+ Type resultEltType = cast<VectorType>(results[0].getType()).getElementType();
Type resultExpandedType = VectorType::get(expandedShape, resultEltType);
Value result = builder.create<arith::ConstantOp>(
resultExpandedType, builder.getZeroAttr(resultExpandedType));
@@ -318,9 +318,9 @@ LogicalResult insertCasts(Operation *op, PatternRewriter &rewriter) {
// Create F32 equivalent type.
Type newType;
- if (auto shaped = origType.dyn_cast<ShapedType>()) {
+ if (auto shaped = dyn_cast<ShapedType>(origType)) {
newType = shaped.clone(rewriter.getF32Type());
- } else if (origType.isa<FloatType>()) {
+ } else if (isa<FloatType>(origType)) {
newType = rewriter.getF32Type();
} else {
return rewriter.notifyMatchFailure(op,
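
The three spellings in this file follow the usual LLVM casting contract: `isa<T>` is a pure predicate, `dyn_cast<T>` returns a null handle on mismatch, and `cast<T>` asserts that the match holds. A minimal sketch of the contract, assuming an MLIR checkout; `rankOrZero` is a hypothetical helper:

```cpp
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Support/LLVM.h"

using namespace mlir;

int64_t rankOrZero(Type type) {
  if (auto vec = dyn_cast<VectorType>(type))     // null handle on mismatch
    return vec.getRank();
  if (!isa<RankedTensorType>(type))              // predicate form
    return 0;
  return cast<RankedTensorType>(type).getRank(); // asserts if mismatched
}
```
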
diff --git a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
index 7f702e197854..ae2472db4f86 100644
--- a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
+++ b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
@@ -69,7 +69,7 @@ DiagnosedSilenceableFailure transform::MemRefMultiBufferOp::apply(
results.push_back(*newBuffer);
}
- transformResults.set(getResult().cast<OpResult>(), results);
+ transformResults.set(cast<OpResult>(getResult()), results);
return DiagnosedSilenceableFailure::success();
}
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp b/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp
index 369f22521895..9b1d85b29027 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp
@@ -57,7 +57,7 @@ struct ComposeSubViewOpPattern : public OpRewritePattern<memref::SubViewOp> {
// always 1.
if (llvm::all_of(strides, [](OpFoldResult &valueOrAttr) {
Attribute attr = valueOrAttr.dyn_cast<Attribute>();
- return attr && attr.cast<IntegerAttr>().getInt() == 1;
+ return attr && cast<IntegerAttr>(attr).getInt() == 1;
})) {
strides = SmallVector<OpFoldResult>(sourceOp.getMixedStrides().size(),
rewriter.getI64IntegerAttr(1));
@@ -93,8 +93,8 @@ struct ComposeSubViewOpPattern : public OpRewritePattern<memref::SubViewOp> {
// If both offsets are static we can simply calculate the combined
// offset statically.
offsets.push_back(rewriter.getI64IntegerAttr(
- opOffsetAttr.cast<IntegerAttr>().getInt() +
- sourceOffsetAttr.cast<IntegerAttr>().getInt()));
+ cast<IntegerAttr>(opOffsetAttr).getInt() +
+ cast<IntegerAttr>(sourceOffsetAttr).getInt()));
} else {
// When either offset is dynamic, we must emit an additional affine
// transformation to add the two offsets together dynamically.
@@ -102,7 +102,7 @@ struct ComposeSubViewOpPattern : public OpRewritePattern<memref::SubViewOp> {
SmallVector<Value> affineApplyOperands;
for (auto valueOrAttr : {opOffset, sourceOffset}) {
if (auto attr = valueOrAttr.dyn_cast<Attribute>()) {
- expr = expr + attr.cast<IntegerAttr>().getInt();
+ expr = expr + cast<IntegerAttr>(attr).getInt();
} else {
expr =
expr + rewriter.getAffineSymbolExpr(affineApplyOperands.size());
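
Note the call that stays in member form above: `valueOrAttr.dyn_cast<Attribute>()` is on an OpFoldResult, which is a PointerUnion rather than one of the casted IR handles, so only the resulting Attribute moves to the free function. A minimal sketch of the mixed idiom, assuming an MLIR checkout; `isUnitStride` is hypothetical:

```cpp
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Support/LLVM.h"

using namespace mlir;

bool isUnitStride(OpFoldResult valueOrAttr) {
  Attribute attr = valueOrAttr.dyn_cast<Attribute>(); // PointerUnion member
  return attr && cast<IntegerAttr>(attr).getInt() == 1;
}
```
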
diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp
index 6202b5730c21..57f0141c95dc 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp
@@ -149,7 +149,7 @@ void memref::populateMemRefWideIntEmulationConversions(
arith::WideIntEmulationConverter &typeConverter) {
typeConverter.addConversion(
[&typeConverter](MemRefType ty) -> std::optional<Type> {
- auto intTy = ty.getElementType().dyn_cast<IntegerType>();
+ auto intTy = dyn_cast<IntegerType>(ty.getElementType());
if (!intTy)
return ty;
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp
index 38fb11348f28..8a276ebbff6a 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp
@@ -89,11 +89,11 @@ public:
LogicalResult matchAndRewrite(memref::ReshapeOp op,
PatternRewriter &rewriter) const final {
- auto shapeType = op.getShape().getType().cast<MemRefType>();
+ auto shapeType = cast<MemRefType>(op.getShape().getType());
if (!shapeType.hasStaticShape())
return failure();
- int64_t rank = shapeType.cast<MemRefType>().getDimSize(0);
+ int64_t rank = cast<MemRefType>(shapeType).getDimSize(0);
SmallVector<OpFoldResult, 4> sizes, strides;
sizes.resize(rank);
strides.resize(rank);
@@ -106,7 +106,7 @@ public:
if (op.getType().isDynamicDim(i)) {
Value index = rewriter.create<arith::ConstantIndexOp>(loc, i);
size = rewriter.create<memref::LoadOp>(loc, op.getShape(), index);
- if (!size.getType().isa<IndexType>())
+ if (!isa<IndexType>(size.getType()))
size = rewriter.create<arith::IndexCastOp>(
loc, rewriter.getIndexType(), size);
sizes[i] = size;
@@ -141,7 +141,7 @@ struct ExpandOpsPass : public memref::impl::ExpandOpsBase<ExpandOpsPass> {
op.getKind() != arith::AtomicRMWKind::minf;
});
target.addDynamicallyLegalOp<memref::ReshapeOp>([](memref::ReshapeOp op) {
- return !op.getShape().getType().cast<MemRefType>().hasStaticShape();
+ return !cast<MemRefType>(op.getShape().getType()).hasStaticShape();
});
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
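
The first ExpandOps hunk keeps a cast that was already redundant before the rewrite: `shapeType` is deduced as MemRefType, so `cast<MemRefType>(shapeType)` is an identity conversion; the automated rewrite preserves such no-ops rather than cleaning them up. A minimal sketch without the identity cast, assuming an MLIR checkout; `reshapeRank` is hypothetical:

```cpp
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Support/LLVM.h"

using namespace mlir;

int64_t reshapeRank(MemRefType shapeType) {
  // The shape operand is a rank-1 memref; its extent is the result rank.
  return shapeType.getDimSize(0);
}
```
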
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
index ea372bffbc0b..ff2c4107ee46 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
@@ -62,7 +62,7 @@ resolveSubviewStridedMetadata(RewriterBase &rewriter,
// Build a plain extract_strided_metadata(memref) from subview(memref).
Location origLoc = subview.getLoc();
Value source = subview.getSource();
- auto sourceType = source.getType().cast<MemRefType>();
+ auto sourceType = cast<MemRefType>(source.getType());
unsigned sourceRank = sourceType.getRank();
auto newExtractStridedMetadata =
@@ -115,7 +115,7 @@ resolveSubviewStridedMetadata(RewriterBase &rewriter,
// The final result is <baseBuffer, offset, sizes, strides>.
// Thus we need 1 + 1 + subview.getRank() + subview.getRank(), to hold all
// the values.
- auto subType = subview.getType().cast<MemRefType>();
+ auto subType = cast<MemRefType>(subview.getType());
unsigned subRank = subType.getRank();
// The sizes of the final type are defined directly by the input sizes of
@@ -338,7 +338,7 @@ SmallVector<OpFoldResult> getExpandedStrides(memref::ExpandShapeOp expandShape,
// Collect the statically known information about the original stride.
Value source = expandShape.getSrc();
- auto sourceType = source.getType().cast<MemRefType>();
+ auto sourceType = cast<MemRefType>(source.getType());
auto [strides, offset] = getStridesAndOffset(sourceType);
OpFoldResult origStride = ShapedType::isDynamic(strides[groupId])
@@ -358,10 +358,9 @@ SmallVector<OpFoldResult> getExpandedStrides(memref::ExpandShapeOp expandShape,
AffineExpr s0 = builder.getAffineSymbolExpr(0);
AffineExpr s1 = builder.getAffineSymbolExpr(1);
for (; doneStrideIdx < *dynSizeIdx; ++doneStrideIdx) {
- int64_t baseExpandedStride = expandedStrides[doneStrideIdx]
- .get<Attribute>()
- .cast<IntegerAttr>()
- .getInt();
+ int64_t baseExpandedStride =
+ cast<IntegerAttr>(expandedStrides[doneStrideIdx].get<Attribute>())
+ .getInt();
expandedStrides[doneStrideIdx] = makeComposedFoldedAffineApply(
builder, expandShape.getLoc(),
(s0 * baseExpandedStride).floorDiv(productOfAllStaticSizes) * s1,
@@ -372,10 +371,9 @@ SmallVector<OpFoldResult> getExpandedStrides(memref::ExpandShapeOp expandShape,
// Now apply the origStride to the remaining dimensions.
AffineExpr s0 = builder.getAffineSymbolExpr(0);
for (; doneStrideIdx < groupSize; ++doneStrideIdx) {
- int64_t baseExpandedStride = expandedStrides[doneStrideIdx]
- .get<Attribute>()
- .cast<IntegerAttr>()
- .getInt();
+ int64_t baseExpandedStride =
+ cast<IntegerAttr>(expandedStrides[doneStrideIdx].get<Attribute>())
+ .getInt();
expandedStrides[doneStrideIdx] = makeComposedFoldedAffineApply(
builder, expandShape.getLoc(), s0 * baseExpandedStride, {origStride});
}
@@ -445,7 +443,7 @@ getCollapsedSize(memref::CollapseShapeOp collapseShape, OpBuilder &builder,
// Build the affine expr of the product of the original sizes involved in that
// group.
Value source = collapseShape.getSrc();
- auto sourceType = source.getType().cast<MemRefType>();
+ auto sourceType = cast<MemRefType>(source.getType());
SmallVector<int64_t, 2> reassocGroup =
collapseShape.getReassociationIndices()[groupId];
@@ -479,7 +477,7 @@ getCollapsedStride(memref::CollapseShapeOp collapseShape, OpBuilder &builder,
"Reassociation group should have at least one dimension");
Value source = collapseShape.getSrc();
- auto sourceType = source.getType().cast<MemRefType>();
+ auto sourceType = cast<MemRefType>(source.getType());
auto [strides, offset] = getStridesAndOffset(sourceType);
@@ -562,7 +560,7 @@ public:
// extract_strided_metadata(reassociative_reshape_like(memref)).
Location origLoc = reshape.getLoc();
Value source = reshape.getSrc();
- auto sourceType = source.getType().cast<MemRefType>();
+ auto sourceType = cast<MemRefType>(source.getType());
unsigned sourceRank = sourceType.getRank();
auto newExtractStridedMetadata =
@@ -650,8 +648,7 @@ public:
if (!allocLikeOp)
return failure();
- auto memRefType =
- allocLikeOp.getResult().getType().template cast<MemRefType>();
+ auto memRefType = cast<MemRefType>(allocLikeOp.getResult().getType());
if (!memRefType.getLayout().isIdentity())
return rewriter.notifyMatchFailure(
allocLikeOp, "alloc-like operations should have been normalized");
@@ -688,7 +685,7 @@ public:
SmallVector<Value> results;
results.reserve(rank * 2 + 2);
- auto baseBufferType = op.getBaseBuffer().getType().cast<MemRefType>();
+ auto baseBufferType = cast<MemRefType>(op.getBaseBuffer().getType());
int64_t offset = 0;
if (allocLikeOp.getType() == baseBufferType)
results.push_back(allocLikeOp);
@@ -737,7 +734,7 @@ public:
if (!getGlobalOp)
return failure();
- auto memRefType = getGlobalOp.getResult().getType().cast<MemRefType>();
+ auto memRefType = cast<MemRefType>(getGlobalOp.getResult().getType());
if (!memRefType.getLayout().isIdentity()) {
return rewriter.notifyMatchFailure(
getGlobalOp,
@@ -759,7 +756,7 @@ public:
SmallVector<Value> results;
results.reserve(rank * 2 + 2);
- auto baseBufferType = op.getBaseBuffer().getType().cast<MemRefType>();
+ auto baseBufferType = cast<MemRefType>(op.getBaseBuffer().getType());
int64_t offset = 0;
if (getGlobalOp.getType() == baseBufferType)
results.push_back(getGlobalOp);
@@ -838,8 +835,7 @@ class ExtractStridedMetadataOpReinterpretCastFolder
return rewriter.notifyMatchFailure(
reinterpretCastOp, "reinterpret_cast source's type is incompatible");
- auto memrefType =
- reinterpretCastOp.getResult().getType().cast<MemRefType>();
+ auto memrefType = cast<MemRefType>(reinterpretCastOp.getResult().getType());
unsigned rank = memrefType.getRank();
SmallVector<OpFoldResult> results;
results.resize_for_overwrite(rank * 2 + 2);
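
The two stride hunks above show how a long postfix chain turns inside-out: the outermost member call becomes the outermost function call, and clang-format then reflows the expression. A minimal sketch of the same conversion, assuming an MLIR checkout; `strideOf` is hypothetical:

```cpp
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Support/LLVM.h"

using namespace mlir;

int64_t strideOf(OpFoldResult ofr) {
  // Before: ofr.get<Attribute>().cast<IntegerAttr>().getInt();
  return cast<IntegerAttr>(ofr.get<Attribute>()).getInt();
}
```
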
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp
index 5141b5f33cfa..05ba6a3f3870 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp
@@ -120,7 +120,7 @@ template <typename TransferLikeOp>
static FailureOr<Value>
getTransferLikeOpSrcMemRef(TransferLikeOp transferLikeOp) {
Value src = transferLikeOp.getSource();
- if (src.getType().isa<MemRefType>())
+ if (isa<MemRefType>(src.getType()))
return src;
return failure();
}
@@ -240,7 +240,7 @@ struct LoadStoreLikeOpRewriter : public OpRewritePattern<LoadStoreLikeOp> {
return rewriter.notifyMatchFailure(loadStoreLikeOp,
"source is not a memref");
Value srcMemRef = *failureOrSrcMemRef;
- auto ldStTy = srcMemRef.getType().cast<MemRefType>();
+ auto ldStTy = cast<MemRefType>(srcMemRef.getType());
unsigned loadStoreRank = ldStTy.getRank();
// Don't waste compile time if there is nothing to rewrite.
if (loadStoreRank == 0)
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
index 72675b03abf6..2c30e98dd107 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
@@ -148,7 +148,7 @@ resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter,
if (collapseShapeOp.getReassociationIndices().empty()) {
auto zeroAffineMap = rewriter.getConstantAffineMap(0);
int64_t srcRank =
- collapseShapeOp.getViewSource().getType().cast<MemRefType>().getRank();
+ cast<MemRefType>(collapseShapeOp.getViewSource().getType()).getRank();
for (int64_t i = 0; i < srcRank; i++) {
OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
rewriter, loc, zeroAffineMap, dynamicIndices);
diff --git a/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp b/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp
index aa1d27dc863e..68b72eff8c97 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp
@@ -71,11 +71,9 @@ propagateSubViewOp(RewriterBase &rewriter,
UnrealizedConversionCastOp conversionOp, SubViewOp op) {
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(op);
- auto newResultType =
- SubViewOp::inferRankReducedResultType(
- op.getType().getShape(), op.getSourceType(), op.getMixedOffsets(),
- op.getMixedSizes(), op.getMixedStrides())
- .cast<MemRefType>();
+ auto newResultType = cast<MemRefType>(SubViewOp::inferRankReducedResultType(
+ op.getType().getShape(), op.getSourceType(), op.getMixedOffsets(),
+ op.getMixedSizes(), op.getMixedStrides()));
Value newSubview = rewriter.create<SubViewOp>(
op.getLoc(), newResultType, conversionOp.getOperand(0),
op.getMixedOffsets(), op.getMixedSizes(), op.getMixedStrides());
diff --git a/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp b/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
index ee1adcce80e5..eb1df2a87b99 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
@@ -61,11 +61,11 @@ static void replaceUsesAndPropagateType(RewriterBase &rewriter,
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(subviewUse);
Type newType = memref::SubViewOp::inferRankReducedResultType(
- subviewUse.getType().getShape(), val.getType().cast<MemRefType>(),
+ subviewUse.getType().getShape(), cast<MemRefType>(val.getType()),
subviewUse.getStaticOffsets(), subviewUse.getStaticSizes(),
subviewUse.getStaticStrides());
Value newSubview = rewriter.create<memref::SubViewOp>(
- subviewUse->getLoc(), newType.cast<MemRefType>(), val,
+ subviewUse->getLoc(), cast<MemRefType>(newType), val,
subviewUse.getMixedOffsets(), subviewUse.getMixedSizes(),
subviewUse.getMixedStrides());
@@ -209,9 +209,9 @@ mlir::memref::multiBuffer(RewriterBase &rewriter, memref::AllocOp allocOp,
for (int64_t i = 0, e = originalShape.size(); i != e; ++i)
sizes[1 + i] = rewriter.getIndexAttr(originalShape[i]);
// Strides is [1, 1 ... 1 ].
- auto dstMemref = memref::SubViewOp::inferRankReducedResultType(
- originalShape, mbMemRefType, offsets, sizes, strides)
- .cast<MemRefType>();
+ auto dstMemref =
+ cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
+ originalShape, mbMemRefType, offsets, sizes, strides));
Value subview = rewriter.create<memref::SubViewOp>(loc, dstMemref, mbAlloc,
offsets, sizes, strides);
LLVM_DEBUG(DBGS() << "--multi-buffered slice: " << subview << "\n");
diff --git a/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp
index c252433d16fa..aa21497fad8f 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp
@@ -180,7 +180,7 @@ bool NormalizeMemRefs::areMemRefsNormalizable(func::FuncOp funcOp) {
llvm::seq<unsigned>(0, callOp.getNumResults())) {
Value oldMemRef = callOp.getResult(resIndex);
if (auto oldMemRefType =
- oldMemRef.getType().dyn_cast<MemRefType>())
+ dyn_cast<MemRefType>(oldMemRef.getType()))
if (!oldMemRefType.getLayout().isIdentity() &&
!isMemRefNormalizable(oldMemRef.getUsers()))
return WalkResult::interrupt();
@@ -192,7 +192,7 @@ bool NormalizeMemRefs::areMemRefsNormalizable(func::FuncOp funcOp) {
for (unsigned argIndex : llvm::seq<unsigned>(0, funcOp.getNumArguments())) {
BlockArgument oldMemRef = funcOp.getArgument(argIndex);
- if (auto oldMemRefType = oldMemRef.getType().dyn_cast<MemRefType>())
+ if (auto oldMemRefType = dyn_cast<MemRefType>(oldMemRef.getType()))
if (!oldMemRefType.getLayout().isIdentity() &&
!isMemRefNormalizable(oldMemRef.getUsers()))
return false;
@@ -226,7 +226,7 @@ void NormalizeMemRefs::updateFunctionSignature(func::FuncOp funcOp,
funcOp.walk([&](func::ReturnOp returnOp) {
for (const auto &operandEn : llvm::enumerate(returnOp.getOperands())) {
Type opType = operandEn.value().getType();
- MemRefType memrefType = opType.dyn_cast<MemRefType>();
+ MemRefType memrefType = dyn_cast<MemRefType>(opType);
// If type is not memref or if the memref type is same as that in
// function's return signature then no update is required.
if (!memrefType || memrefType == resultTypes[operandEn.index()])
@@ -284,7 +284,7 @@ void NormalizeMemRefs::updateFunctionSignature(func::FuncOp funcOp,
if (oldResult.getType() == newResult.getType())
continue;
AffineMap layoutMap =
- oldResult.getType().cast<MemRefType>().getLayout().getAffineMap();
+ cast<MemRefType>(oldResult.getType()).getLayout().getAffineMap();
if (failed(replaceAllMemRefUsesWith(oldResult, /*newMemRef=*/newResult,
/*extraIndices=*/{},
/*indexRemap=*/layoutMap,
@@ -358,7 +358,7 @@ void NormalizeMemRefs::normalizeFuncOpMemRefs(func::FuncOp funcOp,
for (unsigned argIndex :
llvm::seq<unsigned>(0, functionType.getNumInputs())) {
Type argType = functionType.getInput(argIndex);
- MemRefType memrefType = argType.dyn_cast<MemRefType>();
+ MemRefType memrefType = dyn_cast<MemRefType>(argType);
// Check whether argument is of MemRef type. Any other argument type can
// simply be part of the final function signature.
if (!memrefType) {
@@ -422,11 +422,11 @@ void NormalizeMemRefs::normalizeFuncOpMemRefs(func::FuncOp funcOp,
// Replace all uses of the old memrefs.
Value oldMemRef = op->getResult(resIndex);
Value newMemRef = newOp->getResult(resIndex);
- MemRefType oldMemRefType = oldMemRef.getType().dyn_cast<MemRefType>();
+ MemRefType oldMemRefType = dyn_cast<MemRefType>(oldMemRef.getType());
// Check whether the operation result is MemRef type.
if (!oldMemRefType)
continue;
- MemRefType newMemRefType = newMemRef.getType().cast<MemRefType>();
+ MemRefType newMemRefType = cast<MemRefType>(newMemRef.getType());
if (oldMemRefType == newMemRefType)
continue;
// TODO: Assume single layout map. Multiple maps not supported.
@@ -466,7 +466,7 @@ void NormalizeMemRefs::normalizeFuncOpMemRefs(func::FuncOp funcOp,
for (unsigned resIndex :
llvm::seq<unsigned>(0, functionType.getNumResults())) {
Type resType = functionType.getResult(resIndex);
- MemRefType memrefType = resType.dyn_cast<MemRefType>();
+ MemRefType memrefType = dyn_cast<MemRefType>(resType);
// Check whether result is of MemRef type. Any other argument type can
// simply be part of the final function signature.
if (!memrefType) {
@@ -507,7 +507,7 @@ Operation *NormalizeMemRefs::createOpResultsNormalized(func::FuncOp funcOp,
bool resultTypeNormalized = false;
for (unsigned resIndex : llvm::seq<unsigned>(0, oldOp->getNumResults())) {
auto resultType = oldOp->getResult(resIndex).getType();
- MemRefType memrefType = resultType.dyn_cast<MemRefType>();
+ MemRefType memrefType = dyn_cast<MemRefType>(resultType);
// Check whether the operation result is MemRef type.
if (!memrefType) {
resultTypes.push_back(resultType);
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
index 8c544bbd9fb0..526c1c6e198f 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
@@ -40,7 +40,7 @@ struct DimOfShapedTypeOpInterface : public OpRewritePattern<OpTy> {
LogicalResult matchAndRewrite(OpTy dimOp,
PatternRewriter &rewriter) const override {
- OpResult dimValue = dimOp.getSource().template dyn_cast<OpResult>();
+ OpResult dimValue = dyn_cast<OpResult>(dimOp.getSource());
if (!dimValue)
return failure();
auto shapedTypeOp =
@@ -61,8 +61,8 @@ struct DimOfShapedTypeOpInterface : public OpRewritePattern<OpTy> {
return failure();
Value resultShape = reifiedResultShapes[dimValue.getResultNumber()];
- auto resultShapeType = resultShape.getType().dyn_cast<RankedTensorType>();
- if (!resultShapeType || !resultShapeType.getElementType().isa<IndexType>())
+ auto resultShapeType = dyn_cast<RankedTensorType>(resultShape.getType());
+ if (!resultShapeType || !isa<IndexType>(resultShapeType.getElementType()))
return failure();
Location loc = dimOp->getLoc();
@@ -82,7 +82,7 @@ struct DimOfReifyRankedShapedTypeOpInterface : public OpRewritePattern<OpTy> {
LogicalResult matchAndRewrite(OpTy dimOp,
PatternRewriter &rewriter) const override {
- OpResult dimValue = dimOp.getSource().template dyn_cast<OpResult>();
+ OpResult dimValue = dyn_cast<OpResult>(dimOp.getSource());
if (!dimValue)
return failure();
std::optional<int64_t> dimIndex = dimOp.getConstantIndex();
diff --git a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
index 9ffb315587e3..05a069d98ef3 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
@@ -38,14 +38,14 @@ struct CastOpInterface
void generateRuntimeVerification(Operation *op, OpBuilder &builder,
Location loc) const {
auto castOp = cast<CastOp>(op);
- auto srcType = castOp.getSource().getType().cast<BaseMemRefType>();
+ auto srcType = cast<BaseMemRefType>(castOp.getSource().getType());
// Nothing to check if the result is an unranked memref.
- auto resultType = castOp.getType().dyn_cast<MemRefType>();
+ auto resultType = dyn_cast<MemRefType>(castOp.getType());
if (!resultType)
return;
- if (srcType.isa<UnrankedMemRefType>()) {
+ if (isa<UnrankedMemRefType>(srcType)) {
// Check rank.
Value srcRank = builder.create<RankOp>(loc, castOp.getSource());
Value resultRank =
@@ -75,7 +75,7 @@ struct CastOpInterface
// Check dimension sizes.
for (const auto &it : llvm::enumerate(resultType.getShape())) {
// Static dim size -> static/dynamic dim size does not need verification.
- if (auto rankedSrcType = srcType.dyn_cast<MemRefType>())
+ if (auto rankedSrcType = dyn_cast<MemRefType>(srcType))
if (!rankedSrcType.isDynamicDim(it.index()))
continue;
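
The RuntimeOpVerification hunks dispatch on the memref hierarchy: a BaseMemRefType covers both ranked and unranked memrefs, so the code checks for the unranked case before using ranked-only queries. A minimal sketch, assuming an MLIR checkout; `rankOrMinusOne` is hypothetical:

```cpp
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Support/LLVM.h"

using namespace mlir;

int64_t rankOrMinusOne(BaseMemRefType type) {
  if (isa<UnrankedMemRefType>(type))
    return -1;                            // rank only known at runtime
  return cast<MemRefType>(type).getRank();
}
```
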
diff --git a/mlir/lib/Dialect/NVGPU/Transforms/MmaSyncTF32Transform.cpp b/mlir/lib/Dialect/NVGPU/Transforms/MmaSyncTF32Transform.cpp
index 292738de4b52..b9dd174a6b25 100644
--- a/mlir/lib/Dialect/NVGPU/Transforms/MmaSyncTF32Transform.cpp
+++ b/mlir/lib/Dialect/NVGPU/Transforms/MmaSyncTF32Transform.cpp
@@ -42,7 +42,7 @@ struct MmaSyncF32ToTF32Pattern : public OpRewritePattern<nvgpu::MmaSyncOp> {
Location location = op->getLoc();
if (op->hasAttr(op.getTf32EnabledAttrName()) ||
- !op.getMatrixA().getType().cast<VectorType>().getElementType().isF32())
+ !cast<VectorType>(op.getMatrixA().getType()).getElementType().isF32())
return failure();
if (precision == MmaSyncF32Lowering::Unkown)
diff --git a/mlir/lib/Dialect/NVGPU/Transforms/OptimizeSharedMemory.cpp b/mlir/lib/Dialect/NVGPU/Transforms/OptimizeSharedMemory.cpp
index 07e9ae9f8650..486c786892c2 100644
--- a/mlir/lib/Dialect/NVGPU/Transforms/OptimizeSharedMemory.cpp
+++ b/mlir/lib/Dialect/NVGPU/Transforms/OptimizeSharedMemory.cpp
@@ -180,7 +180,7 @@ getShmReadAndWriteOps(Operation *parentOp, Value shmMemRef,
mlir::LogicalResult
mlir::nvgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
Value memrefValue) {
- auto memRefType = memrefValue.getType().dyn_cast<MemRefType>();
+ auto memRefType = dyn_cast<MemRefType>(memrefValue.getType());
if (!memRefType || !NVGPUDialect::hasSharedMemoryAddressSpace(memRefType))
return failure();
diff --git a/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp b/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp
index 7525f9f57bc5..5a0018c31517 100644
--- a/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp
+++ b/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp
@@ -63,7 +63,7 @@ FailureOr<WarpMatrixInfo> nvgpu::getWarpMatrixInfo(Operation *op) {
info.vectorType = writeOp.getVectorType();
} else if (isa<vector::TransferReadOp, vector::ContractionOp,
vector::ExtractStridedSliceOp, arith::ConstantOp>(op)) {
- info.vectorType = op->getResult(0).getType().cast<VectorType>();
+ info.vectorType = cast<VectorType>(op->getResult(0).getType());
} else {
return op->emitError()
<< "unhandled operation type in nvgpu.mma.sync conversion path";
diff --git a/mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp b/mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp
index ddd8ae0c0fd3..5ee53eaad585 100644
--- a/mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp
+++ b/mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp
@@ -14,13 +14,13 @@ using namespace mlir;
using namespace mlir::quant;
static bool isQuantizablePrimitiveType(Type inputType) {
- return inputType.isa<FloatType>();
+ return isa<FloatType>(inputType);
}
ExpressedToQuantizedConverter
ExpressedToQuantizedConverter::forInputType(Type inputType) {
- if (inputType.isa<TensorType, VectorType>()) {
- Type elementType = inputType.cast<ShapedType>().getElementType();
+ if (isa<TensorType, VectorType>(inputType)) {
+ Type elementType = cast<ShapedType>(inputType).getElementType();
if (!isQuantizablePrimitiveType(elementType))
return ExpressedToQuantizedConverter{inputType, nullptr};
return ExpressedToQuantizedConverter{inputType, elementType};
@@ -34,11 +34,11 @@ ExpressedToQuantizedConverter::forInputType(Type inputType) {
Type ExpressedToQuantizedConverter::convert(QuantizedType elementalType) const {
assert(expressedType && "convert() on unsupported conversion");
- if (auto tensorType = inputType.dyn_cast<RankedTensorType>())
+ if (auto tensorType = dyn_cast<RankedTensorType>(inputType))
return RankedTensorType::get(tensorType.getShape(), elementalType);
- if (auto tensorType = inputType.dyn_cast<UnrankedTensorType>())
+ if (auto tensorType = dyn_cast<UnrankedTensorType>(inputType))
return UnrankedTensorType::get(elementalType);
- if (auto vectorType = inputType.dyn_cast<VectorType>())
+ if (auto vectorType = dyn_cast<VectorType>(inputType))
return VectorType::get(vectorType.getShape(), elementalType);
// If the expressed types match, just use the new elemental type.
@@ -50,7 +50,7 @@ Type ExpressedToQuantizedConverter::convert(QuantizedType elementalType) const {
ElementsAttr
UniformQuantizedPerAxisValueConverter::convert(Attribute realValue) {
- if (auto attr = realValue.dyn_cast<DenseFPElementsAttr>()) {
+ if (auto attr = dyn_cast<DenseFPElementsAttr>(realValue)) {
return convert(attr);
}
// TODO: handles sparse elements attribute
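
`isa<TensorType, VectorType>(inputType)` above uses the variadic form of the free function, which succeeds if the type matches any of the listed classes, mirroring the member form it replaces. A minimal sketch, assuming an MLIR checkout; `isShapedContainer` is hypothetical:

```cpp
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Support/LLVM.h"

using namespace mlir;

bool isShapedContainer(Type type) {
  return isa<TensorType, VectorType>(type); // true if any candidate matches
}
```
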
diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
index 18425dea7b19..2da7473bf659 100644
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -49,7 +49,7 @@ transform::GetParentForOp::apply(transform::TransformResults &results,
}
parents.insert(loop);
}
- results.set(getResult().cast<OpResult>(), parents.getArrayRef());
+ results.set(cast<OpResult>(getResult()), parents.getArrayRef());
return DiagnosedSilenceableFailure::success();
}
@@ -116,8 +116,8 @@ transform::LoopOutlineOp::apply(transform::TransformResults &results,
functions.push_back(*outlined);
calls.push_back(call);
}
- results.set(getFunction().cast<OpResult>(), functions);
- results.set(getCall().cast<OpResult>(), calls);
+ results.set(cast<OpResult>(getFunction()), functions);
+ results.set(cast<OpResult>(getCall()), calls);
return DiagnosedSilenceableFailure::success();
}
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index 13f0d769ef4c..ad395a9ac457 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -30,8 +30,8 @@ namespace {
/// Helper function for loop bufferization. Cast the given buffer to the given
/// memref type.
static Value castBuffer(OpBuilder &b, Value buffer, Type type) {
- assert(type.isa<BaseMemRefType>() && "expected BaseMemRefType");
- assert(buffer.getType().isa<BaseMemRefType>() && "expected BaseMemRefType");
+ assert(isa<BaseMemRefType>(type) && "expected BaseMemRefType");
+ assert(isa<BaseMemRefType>(buffer.getType()) && "expected BaseMemRefType");
// If the buffer already has the correct type, no cast is needed.
if (buffer.getType() == type)
return buffer;
@@ -78,7 +78,7 @@ struct ConditionOpInterface
SmallVector<Value> newArgs;
for (const auto &it : llvm::enumerate(conditionOp.getArgs())) {
Value value = it.value();
- if (value.getType().isa<TensorType>()) {
+ if (isa<TensorType>(value.getType())) {
FailureOr<Value> maybeBuffer = getBuffer(rewriter, value, options);
if (failed(maybeBuffer))
return failure();
@@ -141,7 +141,7 @@ struct ExecuteRegionOpInterface
rewriter.setInsertionPointAfter(newOp);
SmallVector<Value> newResults;
for (const auto &it : llvm::enumerate(executeRegionOp->getResultTypes())) {
- if (it.value().isa<TensorType>()) {
+ if (isa<TensorType>(it.value())) {
newResults.push_back(rewriter.create<bufferization::ToTensorOp>(
executeRegionOp.getLoc(), newOp->getResult(it.index())));
} else {
@@ -183,7 +183,7 @@ struct IfOpInterface
// Compute bufferized result types.
SmallVector<Type> newTypes;
for (Value result : ifOp.getResults()) {
- if (!result.getType().isa<TensorType>()) {
+ if (!isa<TensorType>(result.getType())) {
newTypes.push_back(result.getType());
continue;
}
@@ -218,13 +218,13 @@ struct IfOpInterface
assert(value.getDefiningOp() == op && "invalid valid");
// Determine buffer types of the true/false branches.
- auto opResult = value.cast<OpResult>();
+ auto opResult = cast<OpResult>(value);
auto thenValue = thenYieldOp.getOperand(opResult.getResultNumber());
auto elseValue = elseYieldOp.getOperand(opResult.getResultNumber());
BaseMemRefType thenBufferType, elseBufferType;
- if (thenValue.getType().isa<BaseMemRefType>()) {
+ if (isa<BaseMemRefType>(thenValue.getType())) {
// True branch was already bufferized.
- thenBufferType = thenValue.getType().cast<BaseMemRefType>();
+ thenBufferType = cast<BaseMemRefType>(thenValue.getType());
} else {
auto maybeBufferType =
bufferization::getBufferType(thenValue, options, fixedTypes);
@@ -232,9 +232,9 @@ struct IfOpInterface
return failure();
thenBufferType = *maybeBufferType;
}
- if (elseValue.getType().isa<BaseMemRefType>()) {
+ if (isa<BaseMemRefType>(elseValue.getType())) {
// False branch was already bufferized.
- elseBufferType = elseValue.getType().cast<BaseMemRefType>();
+ elseBufferType = cast<BaseMemRefType>(elseValue.getType());
} else {
auto maybeBufferType =
bufferization::getBufferType(elseValue, options, fixedTypes);
@@ -253,7 +253,7 @@ struct IfOpInterface
// Layout maps are different: Promote to fully dynamic layout map.
return getMemRefTypeWithFullyDynamicLayout(
- opResult.getType().cast<TensorType>(), thenBufferType.getMemorySpace());
+ cast<TensorType>(opResult.getType()), thenBufferType.getMemorySpace());
}
};
@@ -262,7 +262,7 @@ struct IfOpInterface
static DenseSet<int64_t> getTensorIndices(ValueRange values) {
DenseSet<int64_t> result;
for (const auto &it : llvm::enumerate(values))
- if (it.value().getType().isa<TensorType>())
+ if (isa<TensorType>(it.value().getType()))
result.insert(it.index());
return result;
}
@@ -275,8 +275,8 @@ DenseSet<int64_t> getEquivalentBuffers(Block::BlockArgListType bbArgs,
unsigned int minSize = std::min(bbArgs.size(), yieldedValues.size());
DenseSet<int64_t> result;
for (unsigned int i = 0; i < minSize; ++i) {
- if (!bbArgs[i].getType().isa<TensorType>() ||
- !yieldedValues[i].getType().isa<TensorType>())
+ if (!isa<TensorType>(bbArgs[i].getType()) ||
+ !isa<TensorType>(yieldedValues[i].getType()))
continue;
if (state.areEquivalentBufferizedValues(bbArgs[i], yieldedValues[i]))
result.insert(i);
@@ -291,7 +291,7 @@ getBuffers(RewriterBase &rewriter, MutableArrayRef<OpOperand> operands,
const BufferizationOptions &options) {
SmallVector<Value> result;
for (OpOperand &opOperand : operands) {
- if (opOperand.get().getType().isa<TensorType>()) {
+ if (isa<TensorType>(opOperand.get().getType())) {
FailureOr<Value> resultBuffer =
getBuffer(rewriter, opOperand.get(), options);
if (failed(resultBuffer))
@@ -361,9 +361,9 @@ static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
// Compute the buffer type of the yielded value.
BaseMemRefType yieldedValueBufferType;
- if (yieldedValue.getType().isa<BaseMemRefType>()) {
+ if (isa<BaseMemRefType>(yieldedValue.getType())) {
// scf.yield was already bufferized.
- yieldedValueBufferType = yieldedValue.getType().cast<BaseMemRefType>();
+ yieldedValueBufferType = cast<BaseMemRefType>(yieldedValue.getType());
} else {
auto maybeBufferType =
bufferization::getBufferType(yieldedValue, options, newFixedTypes);
@@ -379,7 +379,7 @@ static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
// If there is a mismatch between the yielded buffer type and the iter_arg
// buffer type, the buffer type must be promoted to a fully dynamic layout
// map.
- auto yieldedRanked = yieldedValueBufferType.cast<MemRefType>();
+ auto yieldedRanked = cast<MemRefType>(yieldedValueBufferType);
#ifndef NDEBUG
auto iterRanked = initArgBufferType->cast<MemRefType>();
assert(llvm::equal(yieldedRanked.getShape(), iterRanked.getShape()) &&
@@ -388,7 +388,7 @@ static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
"expected same memory space");
#endif // NDEBUG
return getMemRefTypeWithFullyDynamicLayout(
- iterArg.getType().cast<RankedTensorType>(),
+ cast<RankedTensorType>(iterArg.getType()),
yieldedRanked.getMemorySpace());
}
@@ -516,16 +516,16 @@ struct ForOpInterface
const DenseMap<Value, BaseMemRefType> &fixedTypes) const {
auto forOp = cast<scf::ForOp>(op);
assert(getOwnerOfValue(value) == op && "invalid value");
- assert(value.getType().isa<TensorType>() && "expected tensor type");
+ assert(isa<TensorType>(value.getType()) && "expected tensor type");
// Get result/argument number.
unsigned resultNum;
- if (auto bbArg = value.dyn_cast<BlockArgument>()) {
+ if (auto bbArg = dyn_cast<BlockArgument>(value)) {
resultNum =
forOp.getResultForOpOperand(forOp.getOpOperandForRegionIterArg(bbArg))
.getResultNumber();
} else {
- resultNum = value.cast<OpResult>().getResultNumber();
+ resultNum = cast<OpResult>(value).getResultNumber();
}
// Compute the bufferized type.
@@ -560,7 +560,7 @@ struct ForOpInterface
Value initArg = it.value();
Value result = forOp->getResult(it.index());
// If the type is not a tensor, bufferization doesn't need to touch it.
- if (!result.getType().isa<TensorType>()) {
+ if (!isa<TensorType>(result.getType())) {
castedInitArgs.push_back(initArg);
continue;
}
@@ -611,7 +611,7 @@ struct ForOpInterface
auto yieldOp =
cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
for (OpResult opResult : op->getOpResults()) {
- if (!opResult.getType().isa<TensorType>())
+ if (!isa<TensorType>(opResult.getType()))
continue;
// Note: This is overly strict. We should check for aliasing bufferized
@@ -736,7 +736,7 @@ struct WhileOpInterface
for (int64_t idx = 0;
idx < static_cast<int64_t>(conditionOp.getArgs().size()); ++idx) {
Value value = conditionOp.getArgs()[idx];
- if (!value.getType().isa<TensorType>() ||
+ if (!isa<TensorType>(value.getType()) ||
(equivalentYieldsAfter.contains(idx) &&
equivalentYieldsBefore.contains(idx))) {
beforeYieldValues.push_back(value);
@@ -786,7 +786,7 @@ struct WhileOpInterface
Value initArg = it.value();
Value beforeArg = whileOp.getBeforeArguments()[it.index()];
// If the type is not a tensor, bufferization doesn't need to touch it.
- if (!beforeArg.getType().isa<TensorType>()) {
+ if (!isa<TensorType>(beforeArg.getType())) {
castedInitArgs.push_back(initArg);
continue;
}
@@ -799,7 +799,7 @@ struct WhileOpInterface
// The result types of a WhileOp are the same as the "after" bbArg types.
SmallVector<Type> argsTypesAfter = llvm::to_vector(
llvm::map_range(whileOp.getAfterArguments(), [&](BlockArgument bbArg) {
- if (!bbArg.getType().isa<TensorType>())
+ if (!isa<TensorType>(bbArg.getType()))
return bbArg.getType();
// TODO: error handling
return bufferization::getBufferType(bbArg, options)->cast<Type>();
@@ -848,10 +848,10 @@ struct WhileOpInterface
const DenseMap<Value, BaseMemRefType> &fixedTypes) const {
auto whileOp = cast<scf::WhileOp>(op);
assert(getOwnerOfValue(value) == op && "invalid value");
- assert(value.getType().isa<TensorType>() && "expected tensor type");
+ assert(isa<TensorType>(value.getType()) && "expected tensor type");
// Case 1: Block argument of the "before" region.
- if (auto bbArg = value.dyn_cast<BlockArgument>()) {
+ if (auto bbArg = dyn_cast<BlockArgument>(value)) {
if (bbArg.getOwner()->getParent() == &whileOp.getBefore()) {
Value initArg = whileOp.getInits()[bbArg.getArgNumber()];
auto yieldOp = whileOp.getYieldOp();
@@ -865,18 +865,18 @@ struct WhileOpInterface
// The bufferized "after" bbArg type can be directly computed from the
// bufferized "before" bbArg type.
unsigned resultNum;
- if (auto opResult = value.dyn_cast<OpResult>()) {
+ if (auto opResult = dyn_cast<OpResult>(value)) {
resultNum = opResult.getResultNumber();
- } else if (value.cast<BlockArgument>().getOwner()->getParent() ==
+ } else if (cast<BlockArgument>(value).getOwner()->getParent() ==
&whileOp.getAfter()) {
- resultNum = value.cast<BlockArgument>().getArgNumber();
+ resultNum = cast<BlockArgument>(value).getArgNumber();
} else {
llvm_unreachable("invalid value");
}
Value conditionYieldedVal = whileOp.getConditionOp().getArgs()[resultNum];
- if (!conditionYieldedVal.getType().isa<TensorType>()) {
+ if (!isa<TensorType>(conditionYieldedVal.getType())) {
// scf.condition was already bufferized.
- return conditionYieldedVal.getType().cast<BaseMemRefType>();
+ return cast<BaseMemRefType>(conditionYieldedVal.getType());
}
return bufferization::getBufferType(conditionYieldedVal, options,
fixedTypes);
@@ -902,7 +902,7 @@ struct WhileOpInterface
auto conditionOp = whileOp.getConditionOp();
for (const auto &it : llvm::enumerate(conditionOp.getArgs())) {
- if (!it.value().getType().isa<TensorType>())
+ if (!isa<TensorType>(it.value().getType()))
continue;
if (!state.areEquivalentBufferizedValues(
it.value(), conditionOp->getBlock()->getArgument(it.index())))
@@ -913,7 +913,7 @@ struct WhileOpInterface
auto yieldOp = whileOp.getYieldOp();
for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
- if (!it.value().getType().isa<TensorType>())
+ if (!isa<TensorType>(it.value().getType()))
continue;
if (!state.areEquivalentBufferizedValues(
it.value(), yieldOp->getBlock()->getArgument(it.index())))
@@ -971,7 +971,7 @@ struct YieldOpInterface
SmallVector<Value> newResults;
for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
Value value = it.value();
- if (value.getType().isa<TensorType>()) {
+ if (isa<TensorType>(value.getType())) {
FailureOr<Value> maybeBuffer = getBuffer(rewriter, value, options);
if (failed(maybeBuffer))
return failure();
@@ -1110,7 +1110,7 @@ struct ForallOpInterface
const DenseMap<Value, BaseMemRefType> &fixedTypes) const {
auto forallOp = cast<ForallOp>(op);
- if (auto bbArg = value.dyn_cast<BlockArgument>())
+ if (auto bbArg = dyn_cast<BlockArgument>(value))
// A tensor block argument has the same bufferized type as the
// corresponding output operand.
return bufferization::getBufferType(
@@ -1119,8 +1119,8 @@ struct ForallOpInterface
// The bufferized result type is the same as the bufferized type of the
// corresponding output operand.
return bufferization::getBufferType(
- forallOp.getOutputs()[value.cast<OpResult>().getResultNumber()],
- options, fixedTypes);
+ forallOp.getOutputs()[cast<OpResult>(value).getResultNumber()], options,
+ fixedTypes);
}
bool isRepetitiveRegion(Operation *op, unsigned index) const {
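
Two member casts survive in this file: `initArgBufferType->cast<MemRefType>()` and `bufferization::getBufferType(bbArg, options)->cast<Type>()` both go through FailureOr::operator->, which the automated rewrite appears to leave alone. The free-function spelling would dereference first. A minimal sketch, assuming an MLIR checkout; `asRankedMemRef` is hypothetical:

```cpp
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"

using namespace mlir;

MemRefType asRankedMemRef(FailureOr<BaseMemRefType> maybeType) {
  // Member form through operator->: maybeType->cast<MemRefType>();
  return cast<MemRefType>(*maybeType); // assumes the FailureOr holds a value
}
```
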
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
index 2450a0e5fb34..99591493d132 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
@@ -43,7 +43,7 @@ static bool isShapePreserving(ForOp forOp, int64_t arg) {
while (value) {
if (value == forOp.getRegionIterArgs()[arg])
return true;
- OpResult opResult = value.dyn_cast<OpResult>();
+ OpResult opResult = dyn_cast<OpResult>(value);
if (!opResult)
return false;
@@ -91,7 +91,7 @@ struct DimOfIterArgFolder : public OpRewritePattern<OpTy> {
LogicalResult matchAndRewrite(OpTy dimOp,
PatternRewriter &rewriter) const override {
- auto blockArg = dimOp.getSource().template dyn_cast<BlockArgument>();
+ auto blockArg = dyn_cast<BlockArgument>(dimOp.getSource());
if (!blockArg)
return failure();
auto forOp = dyn_cast<ForOp>(blockArg.getParentBlock()->getParentOp());
@@ -139,7 +139,7 @@ struct DimOfLoopResultFolder : public OpRewritePattern<OpTy> {
auto forOp = dimOp.getSource().template getDefiningOp<scf::ForOp>();
if (!forOp)
return failure();
- auto opResult = dimOp.getSource().template cast<OpResult>();
+ auto opResult = cast<OpResult>(dimOp.getSource());
unsigned resultNumber = opResult.getResultNumber();
if (!isShapePreserving(forOp, resultNumber))
return failure();
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
index 6a9f72521142..a85985b84a03 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
@@ -164,8 +164,7 @@ cloneAndUpdateOperands(RewriterBase &rewriter, Operation *op,
clone->walk([&](Operation *nested) {
for (OpOperand &operand : nested->getOpOperands()) {
Operation *def = operand.get().getDefiningOp();
- if ((def && !clone->isAncestor(def)) ||
- operand.get().isa<BlockArgument>())
+ if ((def && !clone->isAncestor(def)) || isa<BlockArgument>(operand.get()))
callback(&operand);
}
});
@@ -346,7 +345,7 @@ void LoopPipelinerInternal::createKernel(
rewriter.setInsertionPointAfter(newOp);
continue;
}
- auto arg = operand->get().dyn_cast<BlockArgument>();
+ auto arg = dyn_cast<BlockArgument>(operand->get());
if (arg && arg.getOwner() == forOp.getBody()) {
// If the value is a loop carried value coming from stage N + 1 remap,
// it will become a direct use.
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 131e8216ef5d..224bec3b26d2 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -496,7 +496,7 @@ getUntiledProducerFromSliceSource(OpOperand *source,
ArrayRef<scf::ForOp> loops) {
std::optional<OpOperand *> destinationIterArg;
auto loopIt = loops.rbegin();
- while (auto iterArg = source->get().dyn_cast<BlockArgument>()) {
+ while (auto iterArg = dyn_cast<BlockArgument>(source->get())) {
scf::ForOp loop = *loopIt;
if (iterArg.getOwner()->getParentOp() != loop)
break;
@@ -505,7 +505,7 @@ getUntiledProducerFromSliceSource(OpOperand *source,
}
if (loopIt == loops.rend())
destinationIterArg = source;
- return {source->get().dyn_cast<OpResult>(), destinationIterArg};
+ return {dyn_cast<OpResult>(source->get()), destinationIterArg};
}
/// Implementation of fusing producer of a single slice by computing the
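
The slice-tracing code above relies on the fact that every Value is either a BlockArgument or an OpResult, so a failed dyn_cast to one implies the other. A minimal sketch, assuming an MLIR checkout; `definingIndex` is hypothetical:

```cpp
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"

using namespace mlir;

unsigned definingIndex(Value v) {
  if (auto arg = dyn_cast<BlockArgument>(v))
    return arg.getArgNumber();
  return cast<OpResult>(v).getResultNumber(); // the only other kind
}
```
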
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp
index f154840b6f65..c22cb6710a7e 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp
@@ -42,8 +42,8 @@ public:
PatternRewriter &rewriter) const override {
SmallVector<NamedAttribute, 4> globalVarAttrs;
- auto ptrType = op.getType().cast<spirv::PointerType>();
- auto pointeeType = ptrType.getPointeeType().cast<spirv::StructType>();
+ auto ptrType = cast<spirv::PointerType>(op.getType());
+ auto pointeeType = cast<spirv::StructType>(ptrType.getPointeeType());
spirv::StructType structType = VulkanLayoutUtils::decorateType(pointeeType);
if (!structType)
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
index c0ab2152675e..9f2755da0922 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
@@ -51,19 +51,19 @@ createGlobalVarForEntryPointArgument(OpBuilder &builder, spirv::FuncOp funcOp,
// info create a variable of type !spirv.ptr<!spirv.struct<elementType>>. If
// not it must already be a !spirv.ptr<!spirv.struct<...>>.
auto varType = funcOp.getFunctionType().getInput(argIndex);
- if (varType.cast<spirv::SPIRVType>().isScalarOrVector()) {
+ if (cast<spirv::SPIRVType>(varType).isScalarOrVector()) {
auto storageClass = abiInfo.getStorageClass();
if (!storageClass)
return nullptr;
varType =
spirv::PointerType::get(spirv::StructType::get(varType), *storageClass);
}
- auto varPtrType = varType.cast<spirv::PointerType>();
- auto varPointeeType = varPtrType.getPointeeType().cast<spirv::StructType>();
+ auto varPtrType = cast<spirv::PointerType>(varType);
+ auto varPointeeType = cast<spirv::StructType>(varPtrType.getPointeeType());
// Set the offset information.
varPointeeType =
- VulkanLayoutUtils::decorateType(varPointeeType).cast<spirv::StructType>();
+ cast<spirv::StructType>(VulkanLayoutUtils::decorateType(varPointeeType));
if (!varPointeeType)
return nullptr;
@@ -98,7 +98,7 @@ getInterfaceVariables(spirv::FuncOp funcOp,
// Starting with version 1.4, the interface’s storage classes are all
// storage classes used in declaring all global variables referenced by the
// entry point’s call tree." We should consider the target environment here.
- switch (var.getType().cast<spirv::PointerType>().getStorageClass()) {
+ switch (cast<spirv::PointerType>(var.getType()).getStorageClass()) {
case spirv::StorageClass::Input:
case spirv::StorageClass::Output:
interfaceVarSet.insert(var.getOperation());
@@ -247,7 +247,7 @@ LogicalResult ProcessInterfaceVarABI::matchAndRewrite(
// at the start of the function. It is probably better to do the load just
// before the use. There might be multiple loads and currently there is no
// easy way to replace all uses with a sequence of operations.
- if (argType.value().cast<spirv::SPIRVType>().isScalarOrVector()) {
+ if (cast<spirv::SPIRVType>(argType.value()).isScalarOrVector()) {
auto zero =
spirv::ConstantOp::getZero(indexType, funcOp.getLoc(), rewriter);
auto loadPtr = rewriter.create<spirv::AccessChainOp>(
@@ -287,7 +287,7 @@ void LowerABIAttributesPass::runOnOperation() {
typeConverter.addSourceMaterialization([](OpBuilder &builder,
spirv::PointerType type,
ValueRange inputs, Location loc) {
- if (inputs.size() != 1 || !inputs[0].getType().isa<spirv::PointerType>())
+ if (inputs.size() != 1 || !isa<spirv::PointerType>(inputs[0].getType()))
return Value();
return builder.create<spirv::BitcastOp>(loc, type, inputs[0]).getResult();
});
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/RewriteInsertsPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/RewriteInsertsPass.cpp
index 51c36bd12db1..f38282f57a2c 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/RewriteInsertsPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/RewriteInsertsPass.cpp
@@ -84,15 +84,13 @@ void RewriteInsertsPass::runOnOperation() {
LogicalResult RewriteInsertsPass::collectInsertionChain(
spirv::CompositeInsertOp op,
SmallVectorImpl<spirv::CompositeInsertOp> &insertions) {
- auto indicesArrayAttr = op.getIndices().cast<ArrayAttr>();
+ auto indicesArrayAttr = cast<ArrayAttr>(op.getIndices());
// TODO: handle nested composite object.
if (indicesArrayAttr.size() == 1) {
- auto numElements = op.getComposite()
- .getType()
- .cast<spirv::CompositeType>()
+ auto numElements = cast<spirv::CompositeType>(op.getComposite().getType())
.getNumElements();
- auto index = indicesArrayAttr[0].cast<IntegerAttr>().getInt();
+ auto index = cast<IntegerAttr>(indicesArrayAttr[0]).getInt();
// Need a last index to collect a sequential chain.
if (index + 1 != numElements)
return failure();
@@ -109,9 +107,9 @@ LogicalResult RewriteInsertsPass::collectInsertionChain(
return failure();
--index;
- indicesArrayAttr = op.getIndices().cast<ArrayAttr>();
+ indicesArrayAttr = cast<ArrayAttr>(op.getIndices());
if ((indicesArrayAttr.size() != 1) ||
- (indicesArrayAttr[0].cast<IntegerAttr>().getInt() != index))
+ (cast<IntegerAttr>(indicesArrayAttr[0]).getInt() != index))
return failure();
}
}
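
The RewriteInserts hunks narrow the same attribute twice per step: first the indices list to an ArrayAttr, then each element to an IntegerAttr. A minimal sketch of that pattern, assuming an MLIR checkout; `firstIndex` is hypothetical:

```cpp
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/Support/LLVM.h"

using namespace mlir;

int64_t firstIndex(Attribute indices) {
  auto array = cast<ArrayAttr>(indices);       // asserts it is an ArrayAttr
  return cast<IntegerAttr>(array[0]).getInt(); // elements narrowed per use
}
```
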
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 5a5cdfe34194..793b02520f23 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -139,7 +139,7 @@ bool SPIRVTypeConverter::allows(spirv::Capability capability) {
// SPIR-V dialect. Keeping it local till the use case arises.
static std::optional<int64_t>
getTypeNumBytes(const SPIRVConversionOptions &options, Type type) {
- if (type.isa<spirv::ScalarType>()) {
+ if (isa<spirv::ScalarType>(type)) {
auto bitWidth = type.getIntOrFloatBitWidth();
// According to the SPIR-V spec:
// "There is no physical size or bit pattern defined for values with boolean
@@ -152,21 +152,21 @@ getTypeNumBytes(const SPIRVConversionOptions &options, Type type) {
return bitWidth / 8;
}
- if (auto complexType = type.dyn_cast<ComplexType>()) {
+ if (auto complexType = dyn_cast<ComplexType>(type)) {
auto elementSize = getTypeNumBytes(options, complexType.getElementType());
if (!elementSize)
return std::nullopt;
return 2 * *elementSize;
}
- if (auto vecType = type.dyn_cast<VectorType>()) {
+ if (auto vecType = dyn_cast<VectorType>(type)) {
auto elementSize = getTypeNumBytes(options, vecType.getElementType());
if (!elementSize)
return std::nullopt;
return vecType.getNumElements() * *elementSize;
}
- if (auto memRefType = type.dyn_cast<MemRefType>()) {
+ if (auto memRefType = dyn_cast<MemRefType>(type)) {
// TODO: Layout should also be controlled by the ABI attributes. For now
// using the layout from MemRef.
int64_t offset;
@@ -198,7 +198,7 @@ getTypeNumBytes(const SPIRVConversionOptions &options, Type type) {
return (offset + memrefSize) * *elementSize;
}
- if (auto tensorType = type.dyn_cast<TensorType>()) {
+ if (auto tensorType = dyn_cast<TensorType>(type)) {
if (!tensorType.hasStaticShape())
return std::nullopt;
@@ -246,12 +246,12 @@ convertScalarType(const spirv::TargetEnv &targetEnv,
return nullptr;
}
- if (auto floatType = type.dyn_cast<FloatType>()) {
+ if (auto floatType = dyn_cast<FloatType>(type)) {
LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
return Builder(targetEnv.getContext()).getF32Type();
}
- auto intType = type.cast<IntegerType>();
+ auto intType = cast<IntegerType>(type);
LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
return IntegerType::get(targetEnv.getContext(), /*width=*/32,
intType.getSignedness());
@@ -319,8 +319,8 @@ convertVectorType(const spirv::TargetEnv &targetEnv,
// Get extension and capability requirements for the given type.
SmallVector<ArrayRef<spirv::Extension>, 1> extensions;
SmallVector<ArrayRef<spirv::Capability>, 2> capabilities;
- type.cast<spirv::CompositeType>().getExtensions(extensions, storageClass);
- type.cast<spirv::CompositeType>().getCapabilities(capabilities, storageClass);
+ cast<spirv::CompositeType>(type).getExtensions(extensions, storageClass);
+ cast<spirv::CompositeType>(type).getCapabilities(capabilities, storageClass);
// If all requirements are met, then we can accept this type as-is.
if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) &&
@@ -415,8 +415,8 @@ static Type convertBoolMemrefType(const spirv::TargetEnv &targetEnv,
<< "using non-8-bit storage for bool types unimplemented");
return nullptr;
}
- auto elementType = IntegerType::get(type.getContext(), numBoolBits)
- .dyn_cast<spirv::ScalarType>();
+ auto elementType = dyn_cast<spirv::ScalarType>(
+ IntegerType::get(type.getContext(), numBoolBits));
if (!elementType)
return nullptr;
Type arrayElemType =
@@ -487,7 +487,7 @@ static Type convertSubByteMemrefType(const spirv::TargetEnv &targetEnv,
static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
const SPIRVConversionOptions &options,
MemRefType type) {
- auto attr = type.getMemorySpace().dyn_cast_or_null<spirv::StorageClassAttr>();
+ auto attr = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
if (!attr) {
LLVM_DEBUG(
llvm::dbgs()
@@ -499,7 +499,7 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
}
spirv::StorageClass storageClass = attr.getValue();
- if (type.getElementType().isa<IntegerType>()) {
+ if (isa<IntegerType>(type.getElementType())) {
if (type.getElementTypeBitWidth() == 1)
return convertBoolMemrefType(targetEnv, options, type, storageClass);
if (type.getElementTypeBitWidth() < 8)
@@ -508,17 +508,17 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
Type arrayElemType;
Type elementType = type.getElementType();
- if (auto vecType = elementType.dyn_cast<VectorType>()) {
+ if (auto vecType = dyn_cast<VectorType>(elementType)) {
arrayElemType =
convertVectorType(targetEnv, options, vecType, storageClass);
- } else if (auto complexType = elementType.dyn_cast<ComplexType>()) {
+ } else if (auto complexType = dyn_cast<ComplexType>(elementType)) {
arrayElemType =
convertComplexType(targetEnv, options, complexType, storageClass);
- } else if (auto scalarType = elementType.dyn_cast<spirv::ScalarType>()) {
+ } else if (auto scalarType = dyn_cast<spirv::ScalarType>(elementType)) {
arrayElemType =
convertScalarType(targetEnv, options, scalarType, storageClass);
- } else if (auto indexType = elementType.dyn_cast<IndexType>()) {
- type = convertIndexElementType(type, options).cast<MemRefType>();
+ } else if (auto indexType = dyn_cast<IndexType>(elementType)) {
+ type = cast<MemRefType>(convertIndexElementType(type, options));
arrayElemType = type.getElementType();
} else {
LLVM_DEBUG(
@@ -583,7 +583,7 @@ SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr,
addConversion([this](IndexType /*indexType*/) { return getIndexType(); });
addConversion([this](IntegerType intType) -> std::optional<Type> {
- if (auto scalarType = intType.dyn_cast<spirv::ScalarType>())
+ if (auto scalarType = dyn_cast<spirv::ScalarType>(intType))
return convertScalarType(this->targetEnv, this->options, scalarType);
if (intType.getWidth() < 8)
return convertSubByteIntegerType(this->options, intType);
@@ -591,7 +591,7 @@ SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr,
});
addConversion([this](FloatType floatType) -> std::optional<Type> {
- if (auto scalarType = floatType.dyn_cast<spirv::ScalarType>())
+ if (auto scalarType = dyn_cast<spirv::ScalarType>(floatType))
return convertScalarType(this->targetEnv, this->options, scalarType);
return Type();
});
@@ -784,7 +784,7 @@ static spirv::PointerType getPushConstantStorageType(unsigned elementCount,
static spirv::GlobalVariableOp getPushConstantVariable(Block &body,
unsigned elementCount) {
for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) {
- auto ptrType = varOp.getType().dyn_cast<spirv::PointerType>();
+ auto ptrType = dyn_cast<spirv::PointerType>(varOp.getType());
if (!ptrType)
continue;
@@ -792,10 +792,9 @@ static spirv::GlobalVariableOp getPushConstantVariable(Block &body,
// block statically used per shader entry point." So we should always reuse
// the existing one.
if (ptrType.getStorageClass() == spirv::StorageClass::PushConstant) {
- auto numElements = ptrType.getPointeeType()
- .cast<spirv::StructType>()
- .getElementType(0)
- .cast<spirv::ArrayType>()
+ auto numElements = cast<spirv::ArrayType>(
+ cast<spirv::StructType>(ptrType.getPointeeType())
+ .getElementType(0))
.getNumElements();
if (numElements == elementCount)
return varOp;
@@ -926,8 +925,8 @@ Value mlir::spirv::getOpenCLElementPtr(SPIRVTypeConverter &typeConverter,
linearizeIndex(indices, strides, offset, indexType, loc, builder);
}
Type pointeeType =
- basePtr.getType().cast<spirv::PointerType>().getPointeeType();
- if (pointeeType.isa<spirv::ArrayType>()) {
+ cast<spirv::PointerType>(basePtr.getType()).getPointeeType();
+ if (isa<spirv::ArrayType>(pointeeType)) {
linearizedIndices.push_back(linearIndex);
return builder.create<spirv::AccessChainOp>(loc, basePtr,
linearizedIndices);
@@ -1015,7 +1014,7 @@ bool SPIRVConversionTarget::isLegalOp(Operation *op) {
// Ensure that all types have been converted to SPIRV types.
if (llvm::any_of(valueTypes,
- [](Type t) { return !t.isa<spirv::SPIRVType>(); }))
+ [](Type t) { return !isa<spirv::SPIRVType>(t); }))
return false;
// Special treatment for global variables, whose type requirements are
@@ -1029,13 +1028,13 @@ bool SPIRVConversionTarget::isLegalOp(Operation *op) {
SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities;
for (Type valueType : valueTypes) {
typeExtensions.clear();
- valueType.cast<spirv::SPIRVType>().getExtensions(typeExtensions);
+ cast<spirv::SPIRVType>(valueType).getExtensions(typeExtensions);
if (failed(checkExtensionRequirements(op->getName(), this->targetEnv,
typeExtensions)))
return false;
typeCapabilities.clear();
- valueType.cast<spirv::SPIRVType>().getCapabilities(typeCapabilities);
+ cast<spirv::SPIRVType>(valueType).getCapabilities(typeCapabilities);
if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv,
typeCapabilities)))
return false;
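
The hunks above apply one mechanical rewrite: member-style casts on `Type` become free-function calls. As an illustrative sketch, not part of the patch (the helper and its name are invented), the migrated spelling looks like this:

```cpp
#include "mlir/IR/BuiltinTypes.h" // IntegerType, VectorType, Type
using namespace mlir;

// Invented example: classify a type using the free-function casts.
static bool isSmallIntOrIntVector(Type type) {
  // Previously: type.dyn_cast<IntegerType>()
  if (auto intType = dyn_cast<IntegerType>(type))
    return intType.getWidth() <= 32;
  // Previously: type.dyn_cast<VectorType>()
  if (auto vecType = dyn_cast<VectorType>(type))
    return isSmallIntOrIntVector(vecType.getElementType());
  return false;
}
```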
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp
index 3cd4937e96f2..44fea8678559 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp
@@ -41,7 +41,7 @@ namespace {
//===----------------------------------------------------------------------===//
Attribute getScalarOrSplatAttr(Type type, int64_t value) {
APInt sizedValue(getElementTypeOrSelf(type).getIntOrFloatBitWidth(), value);
- if (auto intTy = type.dyn_cast<IntegerType>())
+ if (auto intTy = dyn_cast<IntegerType>(type))
return IntegerAttr::get(intTy, sizedValue);
return SplatElementsAttr::get(cast<ShapedType>(type), sizedValue);
@@ -149,7 +149,7 @@ struct ExpandMulExtendedPattern final : OpRewritePattern<MulExtendedOp> {
// Currently, WGSL only supports 32-bit integer types. Any other integer
// types should already have been promoted/demoted to i32.
- auto elemTy = getElementTypeOrSelf(lhs.getType()).cast<IntegerType>();
+ auto elemTy = cast<IntegerType>(getElementTypeOrSelf(lhs.getType()));
if (elemTy.getIntOrFloatBitWidth() != 32)
return rewriter.notifyMatchFailure(
loc,
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp
index 97f16d1b1b95..ea856c748677 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp
@@ -65,16 +65,16 @@ static AliasedResourceMap collectAliasedResources(spirv::ModuleOp moduleOp) {
/// `!spirv.ptr<!spirv.struct<!spirv.rtarray<...>>>`. Returns null type
/// otherwise.
static Type getRuntimeArrayElementType(Type type) {
- auto ptrType = type.dyn_cast<spirv::PointerType>();
+ auto ptrType = dyn_cast<spirv::PointerType>(type);
if (!ptrType)
return {};
- auto structType = ptrType.getPointeeType().dyn_cast<spirv::StructType>();
+ auto structType = dyn_cast<spirv::StructType>(ptrType.getPointeeType());
if (!structType || structType.getNumElements() != 1)
return {};
auto rtArrayType =
- structType.getElementType(0).dyn_cast<spirv::RuntimeArrayType>();
+ dyn_cast<spirv::RuntimeArrayType>(structType.getElementType(0));
if (!rtArrayType)
return {};
@@ -97,7 +97,7 @@ deduceCanonicalResource(ArrayRef<spirv::SPIRVType> types) {
for (const auto &indexedTypes : llvm::enumerate(types)) {
spirv::SPIRVType type = indexedTypes.value();
assert(type.isScalarOrVector());
- if (auto vectorType = type.dyn_cast<VectorType>()) {
+ if (auto vectorType = dyn_cast<VectorType>(type)) {
if (vectorType.getNumElements() % 2 != 0)
return std::nullopt; // Odd-sized vector has special layout
// requirements.
@@ -277,7 +277,7 @@ void ResourceAliasAnalysis::recordIfUnifiable(
if (!elementType)
return; // Unexpected resource variable type.
- auto type = elementType.cast<spirv::SPIRVType>();
+ auto type = cast<spirv::SPIRVType>(elementType);
if (!type.isScalarOrVector())
return; // Unexpected resource element type.
@@ -370,7 +370,7 @@ struct ConvertAccessChain : public ConvertAliasResource<spirv::AccessChainOp> {
Location loc = acOp.getLoc();
- if (srcElemType.isIntOrFloat() && dstElemType.isa<VectorType>()) {
+ if (srcElemType.isIntOrFloat() && isa<VectorType>(dstElemType)) {
// The source indices are for a buffer with scalar element types. Rewrite
// them into a buffer with vector element types. We need to scale the last
// index for the vector as a whole, then add one level of index for inside
@@ -398,7 +398,7 @@ struct ConvertAccessChain : public ConvertAliasResource<spirv::AccessChainOp> {
}
if ((srcElemType.isIntOrFloat() && dstElemType.isIntOrFloat()) ||
- (srcElemType.isa<VectorType>() && dstElemType.isa<VectorType>())) {
+ (isa<VectorType>(srcElemType) && isa<VectorType>(dstElemType))) {
// The source indices are for a buffer with larger bitwidth scalar/vector
// element types. Rewrite them into a buffer with smaller bitwidth element
// types. We only need to scale the last index.
@@ -433,10 +433,10 @@ struct ConvertLoad : public ConvertAliasResource<spirv::LoadOp> {
LogicalResult
matchAndRewrite(spirv::LoadOp loadOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto srcPtrType = loadOp.getPtr().getType().cast<spirv::PointerType>();
- auto srcElemType = srcPtrType.getPointeeType().cast<spirv::SPIRVType>();
- auto dstPtrType = adaptor.getPtr().getType().cast<spirv::PointerType>();
- auto dstElemType = dstPtrType.getPointeeType().cast<spirv::SPIRVType>();
+ auto srcPtrType = cast<spirv::PointerType>(loadOp.getPtr().getType());
+ auto srcElemType = cast<spirv::SPIRVType>(srcPtrType.getPointeeType());
+ auto dstPtrType = cast<spirv::PointerType>(adaptor.getPtr().getType());
+ auto dstElemType = cast<spirv::SPIRVType>(dstPtrType.getPointeeType());
Location loc = loadOp.getLoc();
auto newLoadOp = rewriter.create<spirv::LoadOp>(loc, adaptor.getPtr());
@@ -454,7 +454,7 @@ struct ConvertLoad : public ConvertAliasResource<spirv::LoadOp> {
}
if ((srcElemType.isIntOrFloat() && dstElemType.isIntOrFloat()) ||
- (srcElemType.isa<VectorType>() && dstElemType.isa<VectorType>())) {
+ (isa<VectorType>(srcElemType) && isa<VectorType>(dstElemType))) {
// The source and destination have scalar types of different bitwidths, or
// vector types of different component counts. For such cases, we load
// multiple smaller bitwidth values and construct a larger bitwidth one.
@@ -495,13 +495,13 @@ struct ConvertLoad : public ConvertAliasResource<spirv::LoadOp> {
// type.
Type vectorType = srcElemType;
- if (!srcElemType.isa<VectorType>())
+ if (!isa<VectorType>(srcElemType))
vectorType = VectorType::get({ratio}, dstElemType);
// If both the source and destination are vector types, we need to make
// sure the scalar type is the same for composite construction later.
- if (auto srcElemVecType = srcElemType.dyn_cast<VectorType>())
- if (auto dstElemVecType = dstElemType.dyn_cast<VectorType>()) {
+ if (auto srcElemVecType = dyn_cast<VectorType>(srcElemType))
+ if (auto dstElemVecType = dyn_cast<VectorType>(dstElemType)) {
if (srcElemVecType.getElementType() !=
dstElemVecType.getElementType()) {
int64_t count =
@@ -515,7 +515,7 @@ struct ConvertLoad : public ConvertAliasResource<spirv::LoadOp> {
Value vectorValue = rewriter.create<spirv::CompositeConstructOp>(
loc, vectorType, components);
- if (!srcElemType.isa<VectorType>())
+ if (!isa<VectorType>(srcElemType))
vectorValue =
rewriter.create<spirv::BitcastOp>(loc, srcElemType, vectorValue);
rewriter.replaceOp(loadOp, vectorValue);
@@ -534,9 +534,9 @@ struct ConvertStore : public ConvertAliasResource<spirv::StoreOp> {
matchAndRewrite(spirv::StoreOp storeOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto srcElemType =
- storeOp.getPtr().getType().cast<spirv::PointerType>().getPointeeType();
+ cast<spirv::PointerType>(storeOp.getPtr().getType()).getPointeeType();
auto dstElemType =
- adaptor.getPtr().getType().cast<spirv::PointerType>().getPointeeType();
+ cast<spirv::PointerType>(adaptor.getPtr().getType()).getPointeeType();
if (!srcElemType.isIntOrFloat() || !dstElemType.isIntOrFloat())
return rewriter.notifyMatchFailure(storeOp, "not scalar type");
if (!areSameBitwidthScalarType(srcElemType, dstElemType))
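
The `getRuntimeArrayElementType` change above is typical of how the free-function `dyn_cast` composes when peeling nested types: each step either refines the type or bails out with a null result. A minimal sketch of the same shape, using a hypothetical helper rather than patch code:

```cpp
#include "mlir/IR/BuiltinTypes.h"
using namespace mlir;

// Peel memref<vector<...>> down to the vector's element type; return a
// null Type on any mismatch along the way.
static Type getVectorElementTypeOfMemref(Type type) {
  auto memrefType = dyn_cast<MemRefType>(type);
  if (!memrefType)
    return {};
  auto vecType = dyn_cast<VectorType>(memrefType.getElementType());
  if (!vecType)
    return {};
  return vecType.getElementType();
}
```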
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
index 6e09a848c494..095db6b815f5 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
@@ -159,13 +159,13 @@ void UpdateVCEPass::runOnOperation() {
SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities;
for (Type valueType : valueTypes) {
typeExtensions.clear();
- valueType.cast<spirv::SPIRVType>().getExtensions(typeExtensions);
+ cast<spirv::SPIRVType>(valueType).getExtensions(typeExtensions);
if (failed(checkAndUpdateExtensionRequirements(
op, targetEnv, typeExtensions, deducedExtensions)))
return WalkResult::interrupt();
typeCapabilities.clear();
- valueType.cast<spirv::SPIRVType>().getCapabilities(typeCapabilities);
+ cast<spirv::SPIRVType>(valueType).getCapabilities(typeCapabilities);
if (failed(checkAndUpdateCapabilityRequirements(
op, targetEnv, typeCapabilities, deducedCapabilities)))
return WalkResult::interrupt();
diff --git a/mlir/lib/Dialect/SPIRV/Utils/LayoutUtils.cpp b/mlir/lib/Dialect/SPIRV/Utils/LayoutUtils.cpp
index 67d61f820b62..b19495bc3744 100644
--- a/mlir/lib/Dialect/SPIRV/Utils/LayoutUtils.cpp
+++ b/mlir/lib/Dialect/SPIRV/Utils/LayoutUtils.cpp
@@ -53,7 +53,7 @@ VulkanLayoutUtils::decorateType(spirv::StructType structType,
// must be a runtime array.
assert(memberSize != std::numeric_limits<Size>().max() ||
(i + 1 == e &&
- structType.getElementType(i).isa<spirv::RuntimeArrayType>()));
+ isa<spirv::RuntimeArrayType>(structType.getElementType(i))));
// According to the Vulkan spec:
// "A structure has a base alignment equal to the largest base alignment of
// any of its members."
@@ -79,23 +79,23 @@ VulkanLayoutUtils::decorateType(spirv::StructType structType,
Type VulkanLayoutUtils::decorateType(Type type, VulkanLayoutUtils::Size &size,
VulkanLayoutUtils::Size &alignment) {
- if (type.isa<spirv::ScalarType>()) {
+ if (isa<spirv::ScalarType>(type)) {
alignment = getScalarTypeAlignment(type);
// Vulkan spec does not specify any padding for a scalar type.
size = alignment;
return type;
}
- if (auto structType = type.dyn_cast<spirv::StructType>())
+ if (auto structType = dyn_cast<spirv::StructType>(type))
return decorateType(structType, size, alignment);
- if (auto arrayType = type.dyn_cast<spirv::ArrayType>())
+ if (auto arrayType = dyn_cast<spirv::ArrayType>(type))
return decorateType(arrayType, size, alignment);
- if (auto vectorType = type.dyn_cast<VectorType>())
+ if (auto vectorType = dyn_cast<VectorType>(type))
return decorateType(vectorType, size, alignment);
- if (auto arrayType = type.dyn_cast<spirv::RuntimeArrayType>()) {
+ if (auto arrayType = dyn_cast<spirv::RuntimeArrayType>(type)) {
size = std::numeric_limits<Size>().max();
return decorateType(arrayType, alignment);
}
- if (type.isa<spirv::PointerType>()) {
+ if (isa<spirv::PointerType>(type)) {
// TODO: Add support for `PhysicalStorageBufferAddresses`.
return nullptr;
}
@@ -161,13 +161,13 @@ VulkanLayoutUtils::getScalarTypeAlignment(Type scalarType) {
}
bool VulkanLayoutUtils::isLegalType(Type type) {
- auto ptrType = type.dyn_cast<spirv::PointerType>();
+ auto ptrType = dyn_cast<spirv::PointerType>(type);
if (!ptrType) {
return true;
}
auto storageClass = ptrType.getStorageClass();
- auto structType = ptrType.getPointeeType().dyn_cast<spirv::StructType>();
+ auto structType = dyn_cast<spirv::StructType>(ptrType.getPointeeType());
if (!structType) {
return true;
}
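
A general note on the two call forms seen throughout `decorateType` and `isLegalType`: `cast<T>(x)` asserts that the cast succeeds, while `dyn_cast<T>(x)` returns a null object that must be checked. A hedged sketch with invented helpers:

```cpp
#include "mlir/IR/BuiltinTypes.h"
using namespace mlir;

static int64_t rankOrMinusOne(Type type) {
  // Failable form: yields a null RankedTensorType if `type` is anything else.
  if (auto rankedType = dyn_cast<RankedTensorType>(type))
    return rankedType.getRank();
  return -1;
}

static int64_t rankAsserting(Type type) {
  // Asserting form: traps if `type` is not a RankedTensorType; use only
  // when the invariant is already established.
  return cast<RankedTensorType>(type).getRank();
}
```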
diff --git a/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp
index fc67fea1a493..4a567f48aeb4 100644
--- a/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -64,7 +64,7 @@ struct AssumingOpInterface
rewriter.setInsertionPointAfter(newOp);
SmallVector<Value> newResults;
for (const auto &it : llvm::enumerate(assumingOp->getResultTypes())) {
- if (it.value().isa<TensorType>()) {
+ if (isa<TensorType>(it.value())) {
newResults.push_back(rewriter.create<bufferization::ToTensorOp>(
assumingOp.getLoc(), newOp->getResult(it.index())));
} else {
@@ -116,7 +116,7 @@ struct AssumingYieldOpInterface
auto yieldOp = cast<shape::AssumingYieldOp>(op);
SmallVector<Value> newResults;
for (Value value : yieldOp.getOperands()) {
- if (value.getType().isa<TensorType>()) {
+ if (isa<TensorType>(value.getType())) {
FailureOr<Value> buffer = getBuffer(rewriter, value, options);
if (failed(buffer))
return failure();
diff --git a/mlir/lib/Dialect/Shape/Transforms/OutlineShapeComputation.cpp b/mlir/lib/Dialect/Shape/Transforms/OutlineShapeComputation.cpp
index f23a090a25a0..1a6f868cf21d 100644
--- a/mlir/lib/Dialect/Shape/Transforms/OutlineShapeComputation.cpp
+++ b/mlir/lib/Dialect/Shape/Transforms/OutlineShapeComputation.cpp
@@ -133,7 +133,7 @@ void constructShapeFunc(
for (shape::WithOp withOp : allWithOps) {
Value value = withOp.getOperand();
Value shape = withOp.getShape();
- RankedTensorType rankedType = value.getType().dyn_cast<RankedTensorType>();
+ RankedTensorType rankedType = dyn_cast<RankedTensorType>(value.getType());
if (rankedType == nullptr)
continue;
diff --git a/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp b/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
index 99a619cda7b6..990f8f7327d8 100644
--- a/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
@@ -41,7 +41,7 @@ getBufferizationOptions(bool analysisOnly) {
options.unknownTypeConverterFn = [](Value value, Attribute memorySpace,
const BufferizationOptions &options) {
return getMemRefTypeWithStaticIdentityLayout(
- value.getType().cast<TensorType>(), memorySpace);
+ cast<TensorType>(value.getType()), memorySpace);
};
if (analysisOnly) {
options.testAnalysisOnly = true;
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
index 6fd55c779930..ace8a8867081 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
@@ -28,7 +28,7 @@ using namespace mlir::sparse_tensor;
static std::optional<std::pair<Value, Value>>
genSplitSparseConstant(OpBuilder &builder, Location loc, Value tensor) {
if (auto constOp = tensor.getDefiningOp<arith::ConstantOp>()) {
- if (auto a = constOp.getValue().dyn_cast<SparseElementsAttr>()) {
+ if (auto a = dyn_cast<SparseElementsAttr>(constOp.getValue())) {
auto coordinates = builder.create<arith::ConstantOp>(loc, a.getIndices());
auto values = builder.create<arith::ConstantOp>(loc, a.getValues());
return std::make_pair(coordinates, values);
@@ -94,7 +94,7 @@ OverheadType mlir::sparse_tensor::overheadTypeEncoding(unsigned width) {
OverheadType mlir::sparse_tensor::overheadTypeEncoding(Type tp) {
if (tp.isIndex())
return OverheadType::kIndex;
- if (auto intTp = tp.dyn_cast<IntegerType>())
+ if (auto intTp = dyn_cast<IntegerType>(tp))
return overheadTypeEncoding(intTp.getWidth());
llvm_unreachable("Unknown overhead type");
}
@@ -169,7 +169,7 @@ PrimaryType mlir::sparse_tensor::primaryTypeEncoding(Type elemTp) {
return PrimaryType::kI16;
if (elemTp.isInteger(8))
return PrimaryType::kI8;
- if (auto complexTp = elemTp.dyn_cast<ComplexType>()) {
+ if (auto complexTp = dyn_cast<ComplexType>(elemTp)) {
auto complexEltTp = complexTp.getElementType();
if (complexEltTp.isF64())
return PrimaryType::kC64;
@@ -205,10 +205,10 @@ Value sparse_tensor::genCast(OpBuilder &builder, Location loc, Value value,
return value;
// int <=> index
- if (srcTp.isa<IndexType>() || dstTp.isa<IndexType>())
+ if (isa<IndexType>(srcTp) || isa<IndexType>(dstTp))
return builder.create<arith::IndexCastOp>(loc, dstTp, value);
- const auto srcIntTp = srcTp.dyn_cast_or_null<IntegerType>();
+ const auto srcIntTp = dyn_cast_or_null<IntegerType>(srcTp);
const bool isUnsignedCast = srcIntTp ? srcIntTp.isUnsigned() : false;
return mlir::convertScalarToDtype(builder, loc, value, dstTp, isUnsignedCast);
}
@@ -216,7 +216,7 @@ Value sparse_tensor::genCast(OpBuilder &builder, Location loc, Value value,
Value sparse_tensor::genIndexLoad(OpBuilder &builder, Location loc, Value mem,
Value s) {
Value load = builder.create<memref::LoadOp>(loc, mem, s);
- if (!load.getType().isa<IndexType>()) {
+ if (!isa<IndexType>(load.getType())) {
if (load.getType().getIntOrFloatBitWidth() < 64)
load = builder.create<arith::ExtUIOp>(loc, builder.getI64Type(), load);
load =
@@ -226,14 +226,14 @@ Value sparse_tensor::genIndexLoad(OpBuilder &builder, Location loc, Value mem,
}
mlir::TypedAttr mlir::sparse_tensor::getOneAttr(Builder &builder, Type tp) {
- if (tp.isa<FloatType>())
+ if (isa<FloatType>(tp))
return builder.getFloatAttr(tp, 1.0);
- if (tp.isa<IndexType>())
+ if (isa<IndexType>(tp))
return builder.getIndexAttr(1);
- if (auto intTp = tp.dyn_cast<IntegerType>())
+ if (auto intTp = dyn_cast<IntegerType>(tp))
return builder.getIntegerAttr(tp, APInt(intTp.getWidth(), 1));
- if (tp.isa<RankedTensorType, VectorType>()) {
- auto shapedTp = tp.cast<ShapedType>();
+ if (isa<RankedTensorType, VectorType>(tp)) {
+ auto shapedTp = cast<ShapedType>(tp);
if (auto one = getOneAttr(builder, shapedTp.getElementType()))
return DenseElementsAttr::get(shapedTp, one);
}
@@ -244,13 +244,13 @@ Value mlir::sparse_tensor::genIsNonzero(OpBuilder &builder, mlir::Location loc,
Value v) {
Type tp = v.getType();
Value zero = constantZero(builder, loc, tp);
- if (tp.isa<FloatType>())
+ if (isa<FloatType>(tp))
return builder.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE, v,
zero);
if (tp.isIntOrIndex())
return builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, v,
zero);
- if (tp.dyn_cast<ComplexType>())
+ if (dyn_cast<ComplexType>(tp))
return builder.create<complex::NotEqualOp>(loc, v, zero);
llvm_unreachable("Non-numeric type");
}
@@ -580,12 +580,12 @@ void sparse_tensor::foreachInSparseConstant(
}
// Remap value.
Value val;
- if (attr.getElementType().isa<ComplexType>()) {
- auto valAttr = elems[i].second.cast<ArrayAttr>();
+ if (isa<ComplexType>(attr.getElementType())) {
+ auto valAttr = cast<ArrayAttr>(elems[i].second);
val = builder.create<complex::ConstantOp>(loc, attr.getElementType(),
valAttr);
} else {
- auto valAttr = elems[i].second.cast<TypedAttr>();
+ auto valAttr = cast<TypedAttr>(elems[i].second);
val = builder.create<arith::ConstantOp>(loc, valAttr);
}
assert(val);
@@ -597,7 +597,7 @@ SmallVector<Value> sparse_tensor::loadAll(OpBuilder &builder, Location loc,
size_t size, Value mem,
size_t offsetIdx, Value offsetVal) {
#ifndef NDEBUG
- const auto memTp = mem.getType().cast<MemRefType>();
+ const auto memTp = cast<MemRefType>(mem.getType());
assert(memTp.getRank() == 1);
const DynSize memSh = memTp.getDimSize(0);
assert(ShapedType::isDynamic(memSh) || memSh >= static_cast<DynSize>(size));
@@ -619,7 +619,7 @@ void sparse_tensor::storeAll(OpBuilder &builder, Location loc, Value mem,
ValueRange vs, size_t offsetIdx, Value offsetVal) {
#ifndef NDEBUG
const size_t vsize = vs.size();
- const auto memTp = mem.getType().cast<MemRefType>();
+ const auto memTp = cast<MemRefType>(mem.getType());
assert(memTp.getRank() == 1);
const DynSize memSh = memTp.getDimSize(0);
assert(ShapedType::isDynamic(memSh) || memSh >= static_cast<DynSize>(vsize));
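
Two variants from the hunks above are worth spelling out: `dyn_cast_or_null<T>(x)` additionally tolerates a null input (as in `genCast`), and `isa<A, B>(x)` tests several alternatives at once (as in `getOneAttr`). A small sketch, with invented helper names:

```cpp
#include "mlir/IR/BuiltinTypes.h"
using namespace mlir;

static bool isShapedLike(Type type) {
  // Variadic isa<>: true if `type` is any of the listed alternatives.
  return isa<RankedTensorType, VectorType>(type);
}

static bool isUnsignedInt(Type maybeNullType) {
  // dyn_cast_or_null<> also accepts a null input and returns null for it,
  // where plain dyn_cast<> would assert on null.
  auto intType = dyn_cast_or_null<IntegerType>(maybeNullType);
  return intType && intType.isUnsigned();
}
```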
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
index e04475ea2e8f..9e762892e864 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
@@ -260,7 +260,7 @@ Value reshapeValuesToLevels(OpBuilder &builder, Location loc,
/// `IntegerType`), this also works for `RankedTensorType` and `VectorType`
/// (for which it generates a constant `DenseElementsAttr` of zeros).
inline Value constantZero(OpBuilder &builder, Location loc, Type tp) {
- if (auto ctp = tp.dyn_cast<ComplexType>()) {
+ if (auto ctp = dyn_cast<ComplexType>(tp)) {
auto zeroe = builder.getZeroAttr(ctp.getElementType());
auto zeroa = builder.getArrayAttr({zeroe, zeroe});
return builder.create<complex::ConstantOp>(loc, tp, zeroa);
@@ -271,7 +271,7 @@ inline Value constantZero(OpBuilder &builder, Location loc, Type tp) {
/// Generates a 1-valued constant of the given type. This supports all
/// the same types as `constantZero`.
inline Value constantOne(OpBuilder &builder, Location loc, Type tp) {
- if (auto ctp = tp.dyn_cast<ComplexType>()) {
+ if (auto ctp = dyn_cast<ComplexType>(tp)) {
auto zeroe = builder.getZeroAttr(ctp.getElementType());
auto onee = getOneAttr(builder, ctp.getElementType());
auto zeroa = builder.getArrayAttr({onee, zeroe});
@@ -350,7 +350,7 @@ inline Value constantDimLevelTypeEncoding(OpBuilder &builder, Location loc,
}
inline bool isZeroRankedTensorOrScalar(Type type) {
- auto rtp = type.dyn_cast<RankedTensorType>();
+ auto rtp = dyn_cast<RankedTensorType>(type);
return !rtp || rtp.getRank() == 0;
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
index 731a1a9e460e..d61e54505678 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
@@ -350,7 +350,7 @@ void LoopEmitter::initializeLoopEmit(OpBuilder &builder, Location loc,
// on positions.
for (TensorId t = 0, numTensors = getNumTensors(); t < numTensors; t++) {
const Value tensor = tensors[t];
- const auto rtp = tensor.getType().dyn_cast<RankedTensorType>();
+ const auto rtp = dyn_cast<RankedTensorType>(tensor.getType());
if (!rtp)
// Skip only scalars; zero-ranked tensors still need to be bufferized and
// (probably) filled with zeros by users.
@@ -432,7 +432,7 @@ void LoopEmitter::initializeLoopEmit(OpBuilder &builder, Location loc,
Type indexType = builder.getIndexType();
Value c0 = constantZero(builder, loc, indexType);
for (TensorId t = 0, e = tensors.size(); t < e; t++) {
- auto rtp = tensors[t].getType().dyn_cast<RankedTensorType>();
+ auto rtp = dyn_cast<RankedTensorType>(tensors[t].getType());
if (!rtp)
continue;
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
index 67a3f3d038a1..03715785d284 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
@@ -415,11 +415,11 @@ private:
// check `dstLvl < dstLvlRank` at the top; and only here need to
// assert that `reassoc.size() == dstLvlRank`.
assert(dstLvl < reassoc.size() && "Level is out-of-bounds");
- const auto srcLvls = reassoc[dstLvl].cast<ArrayAttr>();
+ const auto srcLvls = cast<ArrayAttr>(reassoc[dstLvl]);
return llvm::to_vector<2>(
llvm::map_range(srcLvls, [&](Attribute srcLvl) -> Level {
// TODO: replace this with the converter for `LevelAttr`.
- return srcLvl.cast<IntegerAttr>().getValue().getZExtValue();
+ return cast<IntegerAttr>(srcLvl).getValue().getZExtValue();
}));
}
return {dstLvl};
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
index c99c26b9c98c..bb52e08686fe 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
@@ -100,7 +100,7 @@ static Value genLaunchGPUFunc(OpBuilder &builder, gpu::GPUFuncOp gpuFunc,
/// completion. Needs to cast the buffer to an unranked buffer.
static Value genHostRegisterMemref(OpBuilder &builder, Location loc,
Value mem) {
- MemRefType memTp = mem.getType().cast<MemRefType>();
+ MemRefType memTp = cast<MemRefType>(mem.getType());
UnrankedMemRefType resTp =
UnrankedMemRefType::get(memTp.getElementType(), /*memorySpace=*/0);
Value cast = builder.create<memref::CastOp>(loc, resTp, mem);
@@ -133,7 +133,7 @@ static void genBlockingWait(OpBuilder &builder, Location loc,
/// that feature does not seem to be fully supported yet.
static gpu::AllocOp genAllocMemRef(OpBuilder &builder, Location loc, Value mem,
Value token) {
- auto tp = mem.getType().cast<ShapedType>();
+ auto tp = cast<ShapedType>(mem.getType());
auto elemTp = tp.getElementType();
auto shape = tp.getShape();
auto memTp = MemRefType::get(shape, elemTp);
@@ -304,7 +304,7 @@ struct ForallRewriter : public OpRewritePattern<scf::ParallelOp> {
for (OpOperand &o : op->getOpOperands()) {
Value val = o.get();
Block *block;
- if (auto arg = val.dyn_cast<BlockArgument>())
+ if (auto arg = dyn_cast<BlockArgument>(val))
block = arg.getOwner();
else
block = val.getDefiningOp()->getBlock();
@@ -321,7 +321,7 @@ struct ForallRewriter : public OpRewritePattern<scf::ParallelOp> {
Type tp = val.getType();
if (val.getDefiningOp<arith::ConstantOp>())
constants.push_back(val);
- else if (tp.isa<FloatType>() || tp.isIntOrIndex())
+ else if (isa<FloatType>(tp) || tp.isIntOrIndex())
scalars.push_back(val);
else if (isa<MemRefType>(tp))
buffers.push_back(val);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp
index 0c68c4db4fe9..f34ed9779cfd 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp
@@ -111,9 +111,9 @@ Value SpecifierStructBuilder::getInitValue(OpBuilder &builder, Location loc,
Value metaData = builder.create<LLVM::UndefOp>(loc, structType);
SpecifierStructBuilder md(metaData);
if (!source) {
- auto memSizeArrayType = structType.cast<LLVM::LLVMStructType>()
- .getBody()[kMemSizePosInSpecifier]
- .cast<LLVM::LLVMArrayType>();
+ auto memSizeArrayType =
+ cast<LLVM::LLVMArrayType>(cast<LLVM::LLVMStructType>(structType)
+ .getBody()[kMemSizePosInSpecifier]);
Value zero = constantZero(builder, loc, memSizeArrayType.getElementType());
// Fill memSizes array with zero.
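
Where the old member syntax chained left to right (`x.cast<A>().getBody()[i].cast<B>()`), the free-function form nests and reads inside-out, as in the `memSizeArrayType` rewrite above. A hedged sketch of the same nesting over hypothetical types:

```cpp
#include "mlir/IR/BuiltinTypes.h"
using namespace mlir;

// Cast to RankedTensorType first, then cast its element type to
// VectorType: read the nesting from the inside out.
static Type innerElementType(Type type) {
  return cast<VectorType>(cast<RankedTensorType>(type).getElementType())
      .getElementType();
}
```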
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index aebf0542b333..88f79bf3e8d4 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -80,7 +80,7 @@ static void genStore(OpBuilder &builder, Location loc, Value val, Value mem,
Value idx) {
idx = genCast(builder, loc, idx, builder.getIndexType());
val = genCast(builder, loc, val,
- mem.getType().cast<ShapedType>().getElementType());
+ cast<ShapedType>(mem.getType()).getElementType());
builder.create<memref::StoreOp>(loc, val, mem, idx);
}
@@ -253,7 +253,7 @@ static void createAllocFields(OpBuilder &builder, Location loc,
case SparseTensorFieldKind::CrdMemRef:
case SparseTensorFieldKind::ValMemRef:
field = createAllocation(
- builder, loc, fType.cast<MemRefType>(),
+ builder, loc, cast<MemRefType>(fType),
(fKind == SparseTensorFieldKind::PosMemRef) ? posHeuristic
: (fKind == SparseTensorFieldKind::CrdMemRef) ? crdHeuristic
: valHeuristic,
@@ -779,7 +779,7 @@ public:
fields.reserve(desc.getNumFields());
// Memcpy on memref fields.
for (auto field : desc.getMemRefFields()) {
- auto memrefTp = field.getType().cast<MemRefType>();
+ auto memrefTp = cast<MemRefType>(field.getType());
auto size = rewriter.create<memref::DimOp>(loc, field, 0);
auto copied =
rewriter.create<memref::AllocOp>(loc, memrefTp, ValueRange{size});
@@ -1128,7 +1128,7 @@ public:
auto srcDesc = getDescriptorFromTensorTuple(adaptor.getSource());
SmallVector<Value> fields;
foreachFieldAndTypeInSparseTensor(
- SparseTensorType(op.getResult().getType().cast<RankedTensorType>()),
+ SparseTensorType(cast<RankedTensorType>(op.getResult().getType())),
[&rewriter, &fields, srcDesc,
loc](Type fTp, FieldIndex fIdx, SparseTensorFieldKind fKind, Level lvl,
DimLevelType /*dlt*/) -> bool {
@@ -1143,7 +1143,7 @@ public:
// values.
Value sz = linalg::createOrFoldDimOp(rewriter, loc, srcMem, 0);
auto dstMem = rewriter.create<memref::AllocOp>(
- loc, fTp.cast<MemRefType>(), sz);
+ loc, cast<MemRefType>(fTp), sz);
if (fTp != srcMem.getType()) {
// Converts elements type.
scf::buildLoopNest(
@@ -1397,7 +1397,7 @@ struct SparsePackOpConverter : public OpConversionPattern<PackOp> {
}
assert(field);
- if (auto memrefTp = field.getType().dyn_cast<MemRefType>();
+ if (auto memrefTp = dyn_cast<MemRefType>(field.getType());
memrefTp && memrefTp.getRank() > 1) {
ReassociationIndices reassociation;
for (int i = 0, e = memrefTp.getRank(); i < e; i++)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index 8d0c8548097f..906f700cfc47 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -399,7 +399,7 @@ static void genAddEltCall(OpBuilder &builder, Location loc, Type eltType,
/// (which can be either dim- or lvl-coords, depending on context).
static Value genGetNextCall(OpBuilder &builder, Location loc, Value iter,
Value coords, Value elemPtr) {
- Type elemTp = elemPtr.getType().cast<ShapedType>().getElementType();
+ Type elemTp = cast<ShapedType>(elemPtr.getType()).getElementType();
SmallString<10> name{"getNext", primaryTypeFunctionSuffix(elemTp)};
SmallVector<Value, 3> params{iter, coords, elemPtr};
Type i1 = builder.getI1Type();
@@ -1045,7 +1045,7 @@ public:
matchAndRewrite(ToPositionsOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type resTp = op.getType();
- Type posTp = resTp.cast<ShapedType>().getElementType();
+ Type posTp = cast<ShapedType>(resTp).getElementType();
SmallString<17> name{"sparsePositions", overheadTypeFunctionSuffix(posTp)};
Value lvl = constantIndex(rewriter, op->getLoc(), op.getLevel());
replaceOpWithFuncCall(rewriter, op, name, resTp, {adaptor.getTensor(), lvl},
@@ -1064,7 +1064,7 @@ public:
ConversionPatternRewriter &rewriter) const override {
// TODO: use `SparseTensorType::getCrdType` instead.
Type resType = op.getType();
- const Type crdTp = resType.cast<ShapedType>().getElementType();
+ const Type crdTp = cast<ShapedType>(resType).getElementType();
SmallString<19> name{"sparseCoordinates",
overheadTypeFunctionSuffix(crdTp)};
Location loc = op->getLoc();
@@ -1096,7 +1096,7 @@ public:
LogicalResult
matchAndRewrite(ToValuesOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto resType = op.getType().cast<ShapedType>();
+ auto resType = cast<ShapedType>(op.getType());
rewriter.replaceOp(op, genValuesCall(rewriter, op.getLoc(), resType,
adaptor.getOperands()));
return success();
@@ -1113,7 +1113,7 @@ public:
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
// Query values array size for the actually stored values size.
- Type eltType = op.getTensor().getType().cast<ShapedType>().getElementType();
+ Type eltType = cast<ShapedType>(op.getTensor().getType()).getElementType();
auto resTp = MemRefType::get({ShapedType::kDynamic}, eltType);
Value values = genValuesCall(rewriter, loc, resTp, adaptor.getOperands());
rewriter.replaceOpWithNewOp<memref::DimOp>(op, values,
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 2a4bbb06eb50..ca27794b64c1 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -79,7 +79,7 @@ static bool isSampling(GenericOp op) {
// Helper to detect chain of multiplications that do not involve x.
static bool isMulChain(Value val, Value x) {
- if (auto arg = val.dyn_cast<BlockArgument>())
+ if (auto arg = dyn_cast<BlockArgument>(val))
return arg != x;
if (auto *def = val.getDefiningOp()) {
if (isa<arith::MulFOp>(def) || isa<arith::MulIOp>(def))
@@ -105,7 +105,7 @@ static bool isSumOfMul(GenericOp op) {
// Helper to detect direct yield of a zero value.
static bool isZeroYield(GenericOp op) {
auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
- if (auto arg = yieldOp.getOperand(0).dyn_cast<BlockArgument>()) {
+ if (auto arg = dyn_cast<BlockArgument>(yieldOp.getOperand(0))) {
if (arg.getOwner()->getParentOp() == op) {
return isZeroValue(op->getOperand(arg.getArgNumber()));
}
@@ -719,7 +719,7 @@ private:
bool fromSparseConst = false;
if (auto constOp = op.getSource().getDefiningOp<arith::ConstantOp>()) {
- if (constOp.getValue().dyn_cast<SparseElementsAttr>()) {
+ if (dyn_cast<SparseElementsAttr>(constOp.getValue())) {
fromSparseConst = true;
}
}
@@ -972,7 +972,7 @@ public:
// Special-case: for each over a sparse constant uses its own rewriting
// rule.
if (auto constOp = input.getDefiningOp<arith::ConstantOp>()) {
- if (auto attr = constOp.getValue().dyn_cast<SparseElementsAttr>()) {
+ if (auto attr = dyn_cast<SparseElementsAttr>(constOp.getValue())) {
return genForeachOnSparseConstant(op, rewriter, attr);
}
}
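
`Value` is covered by the migration alongside `Type` and `Attribute`: `val.dyn_cast<BlockArgument>()` becomes `dyn_cast<BlockArgument>(val)`, as in `isMulChain` and `isZeroYield` above. An invented predicate in the same style:

```cpp
#include "mlir/IR/Block.h"
#include "mlir/IR/Value.h" // Value, BlockArgument
using namespace mlir;

static bool isEntryBlockArgument(Value val) {
  // Previously: val.dyn_cast<BlockArgument>()
  auto arg = dyn_cast<BlockArgument>(val);
  return arg && arg.getOwner()->isEntryBlock();
}
```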
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h
index 788ad28ee422..a51fcc598ea9 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h
@@ -450,7 +450,7 @@ inline Value genTuple(OpBuilder &builder, Location loc,
inline SparseTensorDescriptor getDescriptorFromTensorTuple(Value tensor) {
auto tuple = getTuple(tensor);
- SparseTensorType stt(tuple.getResultTypes()[0].cast<RankedTensorType>());
+ SparseTensorType stt(cast<RankedTensorType>(tuple.getResultTypes()[0]));
return SparseTensorDescriptor(stt, tuple.getInputs());
}
@@ -458,7 +458,7 @@ inline MutSparseTensorDescriptor
getMutDescriptorFromTensorTuple(Value tensor, SmallVectorImpl<Value> &fields) {
auto tuple = getTuple(tensor);
fields.assign(tuple.getInputs().begin(), tuple.getInputs().end());
- SparseTensorType stt(tuple.getResultTypes()[0].cast<RankedTensorType>());
+ SparseTensorType stt(cast<RankedTensorType>(tuple.getResultTypes()[0]));
return MutSparseTensorDescriptor(stt, fields);
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index afeabb33fcd7..681ba21dd4a3 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -88,9 +88,9 @@ public:
// Overrides method from AffineExprVisitor.
void visitDimExpr(AffineDimExpr expr) {
if (pickedDim == nullptr ||
- pickIterType == iterTypes[expr.getPosition()]
- .cast<linalg::IteratorTypeAttr>()
- .getValue()) {
+ pickIterType ==
+ cast<linalg::IteratorTypeAttr>(iterTypes[expr.getPosition()])
+ .getValue()) {
pickedDim = expr;
}
}
@@ -344,7 +344,7 @@ static unsigned getNumNonTrivialIdxExpOnSparseLvls(AffineMap map,
// we can't use `getRankedTensorType`/`getSparseTensorType` here.
// However, we don't need to handle `StorageSpecifierType`, so we
// can use `SparseTensorType` once we guard against non-tensors.
- const auto rtp = tensor.getType().dyn_cast<RankedTensorType>();
+ const auto rtp = dyn_cast<RankedTensorType>(tensor.getType());
if (!rtp)
return 0;
const SparseTensorType stt(rtp);
@@ -1243,7 +1243,7 @@ static void genExpand(CodegenEnv &env, OpBuilder &builder, LoopOrd at,
Location loc = op.getLoc();
if (atStart) {
auto dynShape = {ShapedType::kDynamic};
- Type etp = tensor.getType().cast<ShapedType>().getElementType();
+ Type etp = cast<ShapedType>(tensor.getType()).getElementType();
Type t1 = MemRefType::get(dynShape, etp);
Type t2 = MemRefType::get(dynShape, builder.getI1Type());
Type t3 = MemRefType::get(dynShape, builder.getIndexType());
@@ -1833,7 +1833,7 @@ public:
// required for sparse tensor slice rank reducing too.
Level maxLvlRank = 0;
for (auto operand : op.getOperands()) {
- if (auto rtp = operand.getType().dyn_cast<RankedTensorType>()) {
+ if (auto rtp = dyn_cast<RankedTensorType>(operand.getType())) {
maxLvlRank = std::max(maxLvlRank, SparseTensorType(rtp).getLvlRank());
}
}
diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
index e8558c1d8d9d..ae31af0cc572 100644
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -1061,8 +1061,8 @@ bool Merger::maybeZero(ExprId e) const {
if (expr.kind == TensorExp::Kind::kInvariant) {
if (auto c = expr.val.getDefiningOp<complex::ConstantOp>()) {
ArrayAttr arrayAttr = c.getValue();
- return arrayAttr[0].cast<FloatAttr>().getValue().isZero() &&
- arrayAttr[1].cast<FloatAttr>().getValue().isZero();
+ return cast<FloatAttr>(arrayAttr[0]).getValue().isZero() &&
+ cast<FloatAttr>(arrayAttr[1]).getValue().isZero();
}
if (auto c = expr.val.getDefiningOp<arith::ConstantIntOp>())
return c.value() == 0;
@@ -1077,7 +1077,7 @@ Type Merger::inferType(ExprId e, Value src) const {
Type dtp = exp(e).val.getType();
// Inspect source type. For vector types, apply the same
// vectorization to the destination type.
- if (auto vtp = src.getType().dyn_cast<VectorType>())
+ if (auto vtp = dyn_cast<VectorType>(src.getType()))
return VectorType::get(vtp.getNumElements(), dtp, vtp.getNumScalableDims());
return dtp;
}
@@ -1085,7 +1085,7 @@ Type Merger::inferType(ExprId e, Value src) const {
/// Ensures that sparse compiler can generate code for expression.
static bool isAdmissibleBranchExp(Operation *op, Block *block, Value v) {
// Arguments are always admissible.
- if (v.isa<BlockArgument>())
+ if (isa<BlockArgument>(v))
return true;
// Accept index anywhere.
Operation *def = v.getDefiningOp();
@@ -1113,7 +1113,7 @@ static bool isAdmissibleBranch(Operation *op, Region &region) {
}
std::optional<ExprId> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
- if (auto arg = v.dyn_cast<BlockArgument>()) {
+ if (auto arg = dyn_cast<BlockArgument>(v)) {
const TensorId tid = makeTensorId(arg.getArgNumber());
// Any argument of the generic op that is not marked as a scalar
// argument is considered a tensor, indexed by the implicit loop
@@ -1346,8 +1346,8 @@ Value Merger::buildExp(RewriterBase &rewriter, Location loc, ExprId e, Value v0,
case TensorExp::Kind::kAbsF:
return rewriter.create<math::AbsFOp>(loc, v0);
case TensorExp::Kind::kAbsC: {
- auto type = v0.getType().cast<ComplexType>();
- auto eltType = type.getElementType().cast<FloatType>();
+ auto type = cast<ComplexType>(v0.getType());
+ auto eltType = cast<FloatType>(type.getElementType());
return rewriter.create<complex::AbsOp>(loc, eltType, v0);
}
case TensorExp::Kind::kAbsI:
@@ -1407,13 +1407,13 @@ Value Merger::buildExp(RewriterBase &rewriter, Location loc, ExprId e, Value v0,
case TensorExp::Kind::kTruncI:
return rewriter.create<arith::TruncIOp>(loc, inferType(e, v0), v0);
case TensorExp::Kind::kCIm: {
- auto type = v0.getType().cast<ComplexType>();
- auto eltType = type.getElementType().cast<FloatType>();
+ auto type = cast<ComplexType>(v0.getType());
+ auto eltType = cast<FloatType>(type.getElementType());
return rewriter.create<complex::ImOp>(loc, eltType, v0);
}
case TensorExp::Kind::kCRe: {
- auto type = v0.getType().cast<ComplexType>();
- auto eltType = type.getElementType().cast<FloatType>();
+ auto type = cast<ComplexType>(v0.getType());
+ auto eltType = cast<FloatType>(type.getElementType());
return rewriter.create<complex::ReOp>(loc, eltType, v0);
}
case TensorExp::Kind::kBitCast:
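
The same free functions apply to `Attribute`, as in the `FloatAttr` accesses in `Merger::maybeZero` above. For illustration only, with an invented helper:

```cpp
#include "mlir/IR/BuiltinAttributes.h" // ArrayAttr, FloatAttr
using namespace mlir;

// True iff the first two array elements are float attributes equal to zero.
static bool firstTwoAreZero(ArrayAttr arrayAttr) {
  if (arrayAttr.size() < 2)
    return false;
  auto lhs = dyn_cast<FloatAttr>(arrayAttr[0]);
  auto rhs = dyn_cast<FloatAttr>(arrayAttr[1]);
  return lhs && rhs && lhs.getValue().isZero() && rhs.getValue().isZero();
}
```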
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 57e5df463343..d93d88630fd8 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -60,20 +60,20 @@ struct CastOpInterface
// type in case the input is an unranked tensor type.
// Case 1: Casting an unranked tensor
- if (castOp.getSource().getType().isa<UnrankedTensorType>()) {
+ if (isa<UnrankedTensorType>(castOp.getSource().getType())) {
// When casting to a ranked tensor, we cannot infer any static offset or
// strides from the source. Assume fully dynamic.
return getMemRefTypeWithFullyDynamicLayout(castOp.getType(), memorySpace);
}
// Case 2: Casting to an unranked tensor type
- if (castOp.getType().isa<UnrankedTensorType>()) {
+ if (isa<UnrankedTensorType>(castOp.getType())) {
return getMemRefTypeWithFullyDynamicLayout(castOp.getType(), memorySpace);
}
// Case 3: Ranked tensor -> ranked tensor. The offsets and strides do not
// change.
- auto rankedResultType = castOp.getType().cast<RankedTensorType>();
+ auto rankedResultType = cast<RankedTensorType>(castOp.getType());
return MemRefType::get(
rankedResultType.getShape(), rankedResultType.getElementType(),
maybeSrcBufferType->cast<MemRefType>().getLayout(), memorySpace);
@@ -158,7 +158,7 @@ struct CollapseShapeOpInterface
if (failed(maybeBuffer))
return failure();
Value buffer = *maybeBuffer;
- auto bufferType = buffer.getType().cast<MemRefType>();
+ auto bufferType = cast<MemRefType>(buffer.getType());
if (tensorResultType.getRank() == 0) {
// 0-d collapses must go through a different op builder.
@@ -383,11 +383,9 @@ struct ExtractSliceOpInterface
SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
- return memref::SubViewOp::inferRankReducedResultType(
- extractSliceOp.getType().getShape(),
- srcMemrefType->cast<MemRefType>(), mixedOffsets, mixedSizes,
- mixedStrides)
- .cast<BaseMemRefType>();
+ return cast<BaseMemRefType>(memref::SubViewOp::inferRankReducedResultType(
+ extractSliceOp.getType().getShape(), srcMemrefType->cast<MemRefType>(),
+ mixedOffsets, mixedSizes, mixedStrides));
}
};
@@ -459,7 +457,7 @@ struct FromElementsOpInterface
auto fromElementsOp = cast<tensor::FromElementsOp>(op);
// Should the buffer be deallocated?
bool dealloc = shouldDeallocateOpResult(
- fromElementsOp.getResult().cast<OpResult>(), options);
+ cast<OpResult>(fromElementsOp.getResult()), options);
// TODO: Implement memory space for this op.
if (options.defaultMemorySpace != Attribute())
@@ -467,7 +465,7 @@ struct FromElementsOpInterface
// Allocate a buffer for the result.
Location loc = op->getLoc();
- auto tensorType = fromElementsOp.getType().cast<RankedTensorType>();
+ auto tensorType = cast<RankedTensorType>(fromElementsOp.getType());
auto shape = tensorType.getShape();
// TODO: Create alloc_tensor ops during TensorCopyInsertion.
FailureOr<Value> tensorAlloc =
@@ -540,7 +538,7 @@ static Value lowerGenerateLikeOpBody(RewriterBase &rewriter, Location loc,
ValueRange dynamicSizes,
Region &generateBody) {
assert(generateBody.hasOneBlock() && "expected body with single block");
- auto tensorType = tensorDestination.getType().cast<RankedTensorType>();
+ auto tensorType = cast<RankedTensorType>(tensorDestination.getType());
assert(generateBody.getNumArguments() == tensorType.getRank() &&
"rank mismatch");
@@ -579,7 +577,7 @@ struct GenerateOpInterface
auto generateOp = cast<tensor::GenerateOp>(op);
// Should the buffer be deallocated?
bool dealloc = shouldDeallocateOpResult(
- generateOp.getResult().cast<OpResult>(), options);
+ cast<OpResult>(generateOp.getResult()), options);
// TODO: Implement memory space for this op.
if (options.defaultMemorySpace != Attribute())
@@ -800,12 +798,11 @@ struct InsertSliceOpInterface
return failure();
// Take a subview of the destination buffer.
- auto dstMemrefType = dstMemref->getType().cast<MemRefType>();
+ auto dstMemrefType = cast<MemRefType>(dstMemref->getType());
auto subviewMemRefType =
- memref::SubViewOp::inferRankReducedResultType(
+ cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
insertSliceOp.getSourceType().getShape(), dstMemrefType,
- mixedOffsets, mixedSizes, mixedStrides)
- .cast<MemRefType>();
+ mixedOffsets, mixedSizes, mixedStrides));
Value subView = rewriter.create<memref::SubViewOp>(
loc, subviewMemRefType, *dstMemref, mixedOffsets, mixedSizes,
mixedStrides);
@@ -900,7 +897,7 @@ struct PadOpInterface
// Should the buffer be deallocated?
bool dealloc =
- shouldDeallocateOpResult(padOp.getResult().cast<OpResult>(), options);
+ shouldDeallocateOpResult(cast<OpResult>(padOp.getResult()), options);
// Allocate a buffer for the padded result.
FailureOr<Value> tensorAlloc =
allocateTensorForShapedValue(rewriter, loc, padOp.getResult(),
@@ -992,7 +989,7 @@ struct ReshapeOpInterface
return failure();
auto resultMemRefType = getMemRefType(
reshapeOp.getResult(), options, /*layout=*/{},
- srcBuffer->getType().cast<BaseMemRefType>().getMemorySpace());
+ cast<BaseMemRefType>(srcBuffer->getType()).getMemorySpace());
replaceOpWithNewBufferizedOp<memref::ReshapeOp>(
rewriter, op, resultMemRefType, *srcBuffer, *shapeBuffer);
return success();
@@ -1039,14 +1036,13 @@ struct ParallelInsertSliceOpInterface
return failure();
// Take a subview of the destination buffer.
- auto destBufferType = destBuffer->getType().cast<MemRefType>();
+ auto destBufferType = cast<MemRefType>(destBuffer->getType());
auto subviewMemRefType =
- memref::SubViewOp::inferRankReducedResultType(
+ cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
parallelInsertSliceOp.getSourceType().getShape(), destBufferType,
parallelInsertSliceOp.getMixedOffsets(),
parallelInsertSliceOp.getMixedSizes(),
- parallelInsertSliceOp.getMixedStrides())
- .cast<MemRefType>();
+ parallelInsertSliceOp.getMixedStrides()));
Value subview = rewriter.create<memref::SubViewOp>(
parallelInsertSliceOp.getLoc(), subviewMemRefType, *destBuffer,
parallelInsertSliceOp.getMixedOffsets(),
diff --git a/mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshapeUtils.cpp b/mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshapeUtils.cpp
index b5e75e081886..968d68e143fe 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshapeUtils.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshapeUtils.cpp
@@ -29,7 +29,7 @@ using namespace mlir::tensor;
/// Get the dimension size of a value of RankedTensor type at the given `dimIdx`.
static OpFoldResult getShapeDimSize(OpBuilder &b, Location loc,
Value rankedTensor, int64_t dimIdx) {
- RankedTensorType tensorType = rankedTensor.getType().cast<RankedTensorType>();
+ RankedTensorType tensorType = cast<RankedTensorType>(rankedTensor.getType());
if (!tensorType.isDynamicDim(dimIdx)) {
return b.getIndexAttr(tensorType.getDimSize(dimIdx));
}
@@ -41,7 +41,7 @@ static OpFoldResult getShapeDimSize(OpBuilder &b, Location loc,
static SmallVector<OpFoldResult> getShapeDimSizes(OpBuilder &b, Location loc,
Value rankedTensor) {
SmallVector<OpFoldResult> dimSizes;
- RankedTensorType tensorType = rankedTensor.getType().cast<RankedTensorType>();
+ RankedTensorType tensorType = cast<RankedTensorType>(rankedTensor.getType());
for (unsigned i = 0; i < tensorType.getRank(); i++)
dimSizes.push_back(getShapeDimSize(b, loc, rankedTensor, i));
return dimSizes;
diff --git a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
index 71dddd136379..4ecb800caab4 100644
--- a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
@@ -44,7 +44,7 @@ PadOp mlir::tensor::createPadHighOp(RankedTensorType type, Value source,
SmallVector<Value> mlir::tensor::createDynamicDimValues(OpBuilder &b,
Location loc,
Value rankedTensor) {
- auto tensorTy = rankedTensor.getType().cast<RankedTensorType>();
+ auto tensorTy = cast<RankedTensorType>(rankedTensor.getType());
SmallVector<Value> dynamicDims;
for (const auto &en : llvm::enumerate(tensorTy.getShape())) {
if (en.value() == ShapedType::kDynamic)
@@ -57,7 +57,7 @@ SmallVector<Value> mlir::tensor::createDynamicDimValues(OpBuilder &b,
FailureOr<OpFoldResult> mlir::tensor::createDimValue(OpBuilder &b, Location loc,
Value rankedTensor,
int64_t dim) {
- auto tensorTy = rankedTensor.getType().dyn_cast<RankedTensorType>();
+ auto tensorTy = dyn_cast<RankedTensorType>(rankedTensor.getType());
if (!tensorTy)
return failure();
auto shape = tensorTy.getShape();
@@ -70,7 +70,7 @@ FailureOr<OpFoldResult> mlir::tensor::createDimValue(OpBuilder &b, Location loc,
SmallVector<OpFoldResult>
mlir::tensor::createDimValues(OpBuilder &b, Location loc, Value rankedTensor) {
- auto tensorTy = rankedTensor.getType().cast<RankedTensorType>();
+ auto tensorTy = cast<RankedTensorType>(rankedTensor.getType());
SmallVector<OpFoldResult> dims;
for (const auto &en : llvm::enumerate(tensorTy.getShape())) {
if (ShapedType::isDynamic(en.value())) {
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp
index 7b4733864972..44f64f76e9b0 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp
@@ -34,9 +34,9 @@ struct Conv2DIsFullyConnected : public OpRewritePattern<tosa::Conv2DOp> {
PatternRewriter &rewriter) const override {
Value input = op.getInput();
Value weight = op.getWeight();
- ShapedType inputType = input.getType().cast<ShapedType>();
- ShapedType weightType = weight.getType().cast<ShapedType>();
- ShapedType resultType = op.getType().cast<ShapedType>();
+ ShapedType inputType = cast<ShapedType>(input.getType());
+ ShapedType weightType = cast<ShapedType>(weight.getType());
+ ShapedType resultType = cast<ShapedType>(op.getType());
auto numDynamic =
llvm::count_if(inputType.getShape(), ShapedType::isDynamic);
@@ -66,7 +66,7 @@ struct Conv2DIsFullyConnected : public OpRewritePattern<tosa::Conv2DOp> {
auto quantizationInfo = op.getQuantizationInfo();
int64_t iZp = quantizationInfo->getInputZp();
- if (!validIntegerRange(inputETy.cast<IntegerType>(), iZp))
+ if (!validIntegerRange(cast<IntegerType>(inputETy), iZp))
return rewriter.notifyMatchFailure(
op, "tosa.conv op quantization has zp outside of input range");
@@ -116,7 +116,7 @@ struct Conv2DIsFullyConnected : public OpRewritePattern<tosa::Conv2DOp> {
weightShape[3]};
auto revisedWeightShapeType = RankedTensorType::get(
revisedWeightShape,
- weight.getType().dyn_cast<RankedTensorType>().getElementType());
+ dyn_cast<RankedTensorType>(weight.getType()).getElementType());
auto reshapedWeight = rewriter
.create<tosa::ReshapeOp>(
op.getLoc(), revisedWeightShapeType, weight,
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
index 81ec7fd66379..488e46d1339a 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
@@ -28,9 +28,9 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
PatternRewriter &rewriter) const override {
Value input = op.getInput();
Value weight = op.getWeight();
- ShapedType inputType = input.getType().cast<ShapedType>();
- ShapedType weightType = weight.getType().cast<ShapedType>();
- ShapedType resultType = op.getOutput().getType().cast<ShapedType>();
+ ShapedType inputType = cast<ShapedType>(input.getType());
+ ShapedType weightType = cast<ShapedType>(weight.getType());
+ ShapedType resultType = cast<ShapedType>(op.getOutput().getType());
if (!(inputType.hasStaticShape() && weightType.hasStaticShape() &&
resultType.hasStaticShape())) {
@@ -52,7 +52,7 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
inputShape[0], inputShape[1], inputShape[2], inputShape[3], 1};
inputType = RankedTensorType::get(
revisedInputShape,
- input.getType().dyn_cast<RankedTensorType>().getElementType());
+ dyn_cast<RankedTensorType>(input.getType()).getElementType());
input = rewriter
.create<tosa::ReshapeOp>(
op.getLoc(), inputType, input,
@@ -76,7 +76,7 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
auto applyZp = [&](Value val, int64_t zp) -> Value {
if (zp == 0)
return val;
- auto ety = val.getType().cast<ShapedType>().getElementType();
+ auto ety = cast<ShapedType>(val.getType()).getElementType();
auto zpTy = RankedTensorType::get({}, ety);
auto zpAttr =
DenseElementsAttr::get(zpTy, rewriter.getIntegerAttr(ety, zp));
@@ -126,17 +126,17 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
inputType.getDimSize(2), inputType.getDimSize(3), weightShape[3]};
auto mulShapeType = RankedTensorType::get(
mulShape,
- weight.getType().dyn_cast<RankedTensorType>().getElementType());
+ dyn_cast<RankedTensorType>(weight.getType()).getElementType());
Value mulValue = rewriter
.create<tosa::MulOp>(op.getLoc(), mulShapeType, input,
weight, /*shift=*/0)
.getResult();
// Reshape output to [N, H, W, C * M].
- auto outputShape = op.getOutput().getType().cast<ShapedType>().getShape();
+ auto outputShape = cast<ShapedType>(op.getOutput().getType()).getShape();
auto outputShapeType = RankedTensorType::get(
outputShape,
- input.getType().dyn_cast<RankedTensorType>().getElementType());
+ dyn_cast<RankedTensorType>(input.getType()).getElementType());
auto outputValue = rewriter.create<tosa::ReshapeOp>(
op.getLoc(), outputShapeType, mulValue,
rewriter.getDenseI64ArrayAttr(outputShape));
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
index 74533defd055..87563c1761a8 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
@@ -56,7 +56,7 @@ TosaOp createOpAndInfer(PatternRewriter &rewriter, Location loc, Type resultTy,
// Compute the knowledge based on the inferred type.
auto inferredKnowledge =
mlir::tosa::ValueKnowledge::getPessimisticValueState();
- inferredKnowledge.dtype = resultTy.cast<ShapedType>().getElementType();
+ inferredKnowledge.dtype = cast<ShapedType>(resultTy).getElementType();
inferredKnowledge.hasRank = predictedShape.hasRank();
if (predictedShape.hasRank()) {
for (auto dim : predictedShape.getDims()) {
@@ -83,10 +83,10 @@ public:
Value weight = op->getOperand(1);
Value bias = op->getOperand(2);
- ShapedType inputTy = input.getType().cast<ShapedType>();
- ShapedType weightTy = weight.getType().cast<ShapedType>();
- ShapedType biasTy = bias.getType().cast<ShapedType>();
- ShapedType resultTy = op->getResult(0).getType().cast<ShapedType>();
+ ShapedType inputTy = cast<ShapedType>(input.getType());
+ ShapedType weightTy = cast<ShapedType>(weight.getType());
+ ShapedType biasTy = cast<ShapedType>(bias.getType());
+ ShapedType resultTy = cast<ShapedType>(op->getResult(0).getType());
llvm::ArrayRef<int64_t> stride = op.getStride();
llvm::ArrayRef<int64_t> pad = op.getOutPad();
@@ -146,10 +146,10 @@ public:
Value weight = op->getOperand(1);
Value bias = op->getOperand(2);
- ShapedType inputTy = input.getType().cast<ShapedType>();
- ShapedType weightTy = weight.getType().cast<ShapedType>();
- ShapedType biasTy = bias.getType().cast<ShapedType>();
- ShapedType resultTy = op->getResult(0).getType().cast<ShapedType>();
+ ShapedType inputTy = cast<ShapedType>(input.getType());
+ ShapedType weightTy = cast<ShapedType>(weight.getType());
+ ShapedType biasTy = cast<ShapedType>(bias.getType());
+ ShapedType resultTy = cast<ShapedType>(op->getResult(0).getType());
Type inputETy = inputTy.getElementType();
Type weightETy = weightTy.getElementType();
@@ -202,7 +202,7 @@ public:
weight, weightPaddingVal);
}
- weightTy = weight.getType().cast<ShapedType>();
+ weightTy = cast<ShapedType>(weight.getType());
weightHeight = weightTy.getDimSize(1);
weightWidth = weightTy.getDimSize(2);
@@ -231,7 +231,7 @@ public:
weight = createOpAndInfer<tosa::ReshapeOp>(
rewriter, loc, UnrankedTensorType::get(weightETy), weight,
rewriter.getDenseI64ArrayAttr(weightReshapeDims1));
- ShapedType restridedWeightTy = weight.getType().cast<ShapedType>();
+ ShapedType restridedWeightTy = cast<ShapedType>(weight.getType());
weight = createOpAndInfer<tosa::ReverseOp>(
rewriter, loc, UnrankedTensorType::get(weightETy), weight,
@@ -297,7 +297,7 @@ public:
}
// Factor the resulting width / height.
- ShapedType convTy = conv2d.getType().cast<ShapedType>();
+ ShapedType convTy = cast<ShapedType>(conv2d.getType());
Type convETy = convTy.getElementType();
int64_t convHeight = convTy.getDimSize(1);
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantTranspose.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantTranspose.cpp
index 9e2102ee1d0a..302e2793f0a3 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantTranspose.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantTranspose.cpp
@@ -72,7 +72,7 @@ DenseElementsAttr transpose(ElementsAttr attr, ShapedType inputType,
auto baseType = inputType.getElementType();
// Handle possible integer types
- if (auto intType = baseType.dyn_cast<IntegerType>()) {
+ if (auto intType = dyn_cast<IntegerType>(baseType)) {
switch (intType.getWidth()) {
case 1:
return transposeType<bool>(attr, inputType, outputType, permValues);
@@ -102,7 +102,7 @@ struct TosaFoldConstantTranspose : public OpRewritePattern<tosa::TransposeOp> {
LogicalResult matchAndRewrite(tosa::TransposeOp op,
PatternRewriter &rewriter) const override {
- auto outputType = op.getType().cast<ShapedType>();
+ auto outputType = cast<ShapedType>(op.getType());
// TOSA supports quantized types.
if (!outputType.getElementType().isIntOrIndexOrFloat())
return failure();
@@ -122,7 +122,7 @@ struct TosaFoldConstantTranspose : public OpRewritePattern<tosa::TransposeOp> {
permAttr.getValues<APInt>(),
[](const APInt &val) { return val.getSExtValue(); }));
- auto inputType = op.getInput1().getType().cast<ShapedType>();
+ auto inputType = cast<ShapedType>(op.getInput1().getType());
auto resultAttr = transpose(inputValues, inputType, outputType, permValues);
rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, outputType, resultAttr);
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
index 0c03cecf61bc..3e2da9df3f94 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
@@ -54,7 +54,7 @@ void propagateShapesToTosaIf(
for (unsigned int i = 1, s = op.getNumOperands(); i < s; i++) {
auto inferredTy = shapesStorage[op.getOperand(i)];
auto blockArg = frontBlock.getArgument(i - 1);
- auto oldType = blockArg.getType().cast<ShapedType>();
+ auto oldType = cast<ShapedType>(blockArg.getType());
if (inferredTy.hasRank()) {
Type newType = oldType.clone(inferredTy.getDims());
@@ -89,7 +89,7 @@ void propagateShapesToTosaWhile(
// loop body / condition for tosa.while.
llvm::SmallVector<Type> argTypes;
for (auto operand : op.getOperands()) {
- auto operandTy = operand.getType().cast<ShapedType>();
+ auto operandTy = cast<ShapedType>(operand.getType());
auto shapedTypeComponent = shapesStorage[operand];
if (shapedTypeComponent.hasRank()) {
auto newTy = operandTy.clone(shapedTypeComponent.getDims());
@@ -188,7 +188,7 @@ void propagateShapesToTosaWhile(
void propagateShapesInRegion(Region &region) {
DenseMap<Value, ShapedTypeComponents> shapesStorage;
auto setShapes = [&](Value val, Type t) {
- if (auto st = t.dyn_cast<ShapedType>())
+ if (auto st = dyn_cast<ShapedType>(t))
shapesStorage[val] = st;
else
shapesStorage[val] = t;
@@ -247,8 +247,7 @@ void propagateShapesInRegion(Region &region) {
// Compute the knowledge based on the inferred type.
auto inferredKnowledge = ValueKnowledge::getPessimisticValueState();
- inferredKnowledge.dtype =
- resultTy.cast<ShapedType>().getElementType();
+ inferredKnowledge.dtype = cast<ShapedType>(resultTy).getElementType();
inferredKnowledge.hasRank = predictedShape.hasRank();
if (predictedShape.hasRank()) {
for (auto dim : predictedShape.getDims()) {
@@ -274,7 +273,7 @@ void propagateShapesInRegion(Region &region) {
for (auto it : shapesStorage) {
auto result = it.second;
if (result.hasRank()) {
- Type t = it.first.getType().cast<ShapedType>().clone(result.getDims());
+ Type t = cast<ShapedType>(it.first.getType()).clone(result.getDims());
it.first.setType(t);
}
}
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp
index b18e3b4bd277..bcfcbbbbcee6 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp
@@ -82,8 +82,8 @@ static LogicalResult reshapeLowerToHigher(PatternRewriter &rewriter,
Location loc,
RankedTensorType outputType,
Value &input1, Value &input2) {
- auto input1Ty = input1.getType().dyn_cast<RankedTensorType>();
- auto input2Ty = input2.getType().dyn_cast<RankedTensorType>();
+ auto input1Ty = dyn_cast<RankedTensorType>(input1.getType());
+ auto input2Ty = dyn_cast<RankedTensorType>(input2.getType());
if (!input1Ty || !input2Ty) {
return rewriter.notifyMatchFailure(loc, "input not a ranked tensor");
@@ -106,9 +106,9 @@ static LogicalResult reshapeLowerToHigher(PatternRewriter &rewriter,
}
ArrayRef<int64_t> higherRankShape =
- higherTensorValue.getType().cast<RankedTensorType>().getShape();
+ cast<RankedTensorType>(higherTensorValue.getType()).getShape();
ArrayRef<int64_t> lowerRankShape =
- lowerTensorValue.getType().cast<RankedTensorType>().getShape();
+ cast<RankedTensorType>(lowerTensorValue.getType()).getShape();
SmallVector<int64_t, 4> reshapeOutputShape;
@@ -116,7 +116,7 @@ static LogicalResult reshapeLowerToHigher(PatternRewriter &rewriter,
.failed())
return rewriter.notifyMatchFailure(loc, "fail to compute a reshape type");
- auto reshapeInputType = lowerTensorValue.getType().cast<RankedTensorType>();
+ auto reshapeInputType = cast<RankedTensorType>(lowerTensorValue.getType());
auto reshapeOutputType = RankedTensorType::get(
ArrayRef<int64_t>(reshapeOutputShape), reshapeInputType.getElementType());
@@ -155,7 +155,7 @@ struct ConvertTosaOp : public OpRewritePattern<OpTy> {
Value input2 = tosaBinaryOp.getInput2();
Value output = tosaBinaryOp.getResult();
- auto outputType = output.getType().dyn_cast<RankedTensorType>();
+ auto outputType = dyn_cast<RankedTensorType>(output.getType());
if (!outputType)
return failure();
@@ -183,7 +183,7 @@ struct ConvertTosaOp<tosa::MulOp> : public OpRewritePattern<tosa::MulOp> {
Value input2 = tosaBinaryOp.getInput2();
int32_t shift = tosaBinaryOp.getShift();
Value output = tosaBinaryOp.getResult();
- auto outputType = output.getType().dyn_cast<RankedTensorType>();
+ auto outputType = dyn_cast<RankedTensorType>(output.getType());
if (!outputType)
return failure();
@@ -214,7 +214,7 @@ struct ConvertTosaOp<tosa::ArithmeticRightShiftOp>
Value input2 = tosaBinaryOp.getInput2();
int32_t round = tosaBinaryOp.getRound();
Value output = tosaBinaryOp.getResult();
- auto outputType = output.getType().dyn_cast<RankedTensorType>();
+ auto outputType = dyn_cast<RankedTensorType>(output.getType());
if (!outputType)
return failure();
@@ -242,7 +242,7 @@ struct ConvertTosaOp<tosa::SelectOp> : public OpRewritePattern<tosa::SelectOp> {
Value input3 = tosaOp.getOnFalse();
Value output = tosaOp.getResult();
- auto outputType = output.getType().dyn_cast<RankedTensorType>();
+ auto outputType = dyn_cast<RankedTensorType>(output.getType());
if (!outputType)
return rewriter.notifyMatchFailure(tosaOp, "output not a ranked tensor");
@@ -265,9 +265,9 @@ struct ConvertTosaOp<tosa::SelectOp> : public OpRewritePattern<tosa::SelectOp> {
tosaOp,
"cannot rewrite as the rank of all operands is already aligned");
- int32_t result1Rank = input1.getType().cast<RankedTensorType>().getRank();
- int32_t result2Rank = input2.getType().cast<RankedTensorType>().getRank();
- int32_t result3Rank = input3.getType().cast<RankedTensorType>().getRank();
+ int32_t result1Rank = cast<RankedTensorType>(input1.getType()).getRank();
+ int32_t result2Rank = cast<RankedTensorType>(input2.getType()).getRank();
+ int32_t result3Rank = cast<RankedTensorType>(input3.getType()).getRank();
if ((result1Rank != result2Rank) || (result2Rank != result3Rank))
return rewriter.notifyMatchFailure(
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 4cb727b00ca0..5605080384bd 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -106,7 +106,7 @@ void TosaValidation::runOnOperation() {
getOperation().walk([&](Operation *op) {
for (Value operand : op->getOperands()) {
if ((profileType == TosaProfileEnum::BaseInference) &&
- getElementTypeOrSelf(operand).isa<FloatType>()) {
+ isa<FloatType>(getElementTypeOrSelf(operand))) {
return signalPassFailure();
}
if (getElementTypeOrSelf(operand).isF64()) {
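The TosaValidation change above uses `isa<>` purely as a predicate. A sketch under the same assumptions; `isFloatOperand` is a hypothetical name:

```
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h" // getElementTypeOrSelf

// isa<T>(x) replaces x.isa<T>(): use it when only a yes/no answer is needed
// and the casted value itself is never used.
bool isFloatOperand(mlir::Value operand) {
  return mlir::isa<mlir::FloatType>(mlir::getElementTypeOrSelf(operand));
}
```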
diff --git a/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp b/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp
index 0b5fc451115c..1c4ae1f27319 100644
--- a/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp
+++ b/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp
@@ -116,16 +116,16 @@ ConvOpQuantizationAttr
mlir::tosa::buildConvOpQuantizationAttr(OpBuilder &builder, Value input,
Value weight) {
- auto inputType = input.getType().dyn_cast<ShapedType>();
- auto weightType = weight.getType().dyn_cast<ShapedType>();
+ auto inputType = dyn_cast<ShapedType>(input.getType());
+ auto weightType = dyn_cast<ShapedType>(weight.getType());
if (!inputType || !weightType)
return nullptr;
auto inputQType = GET_UQTYPE(inputType);
auto weightPerTensorQType = GET_UQTYPE(weightType);
- auto weightPerAxisQType = weightType.getElementType()
- .dyn_cast<quant::UniformQuantizedPerAxisType>();
+ auto weightPerAxisQType =
+ dyn_cast<quant::UniformQuantizedPerAxisType>(weightType.getElementType());
// Weights must be either per-tensor quantized or per-axis quantized.
assert(!((bool)weightPerTensorQType && (bool)weightPerAxisQType) &&
@@ -160,8 +160,8 @@ MatMulOpQuantizationAttr
mlir::tosa::buildMatMulOpQuantizationAttr(OpBuilder &builder, Value a,
Value b) {
- auto aType = a.getType().dyn_cast<ShapedType>();
- auto bType = b.getType().dyn_cast<ShapedType>();
+ auto aType = dyn_cast<ShapedType>(a.getType());
+ auto bType = dyn_cast<ShapedType>(b.getType());
if (!aType || !bType)
return nullptr;
@@ -189,8 +189,8 @@ UnaryOpQuantizationAttr
mlir::tosa::buildUnaryOpQuantizationAttr(OpBuilder &builder, Value input,
Type outputRawType) {
- auto inputType = input.getType().dyn_cast<ShapedType>();
- auto outputType = outputRawType.dyn_cast<ShapedType>();
+ auto inputType = dyn_cast<ShapedType>(input.getType());
+ auto outputType = dyn_cast<ShapedType>(outputRawType);
if (!inputType || !outputType)
return nullptr;
@@ -215,7 +215,7 @@ mlir::tosa::buildUnaryOpQuantizationAttr(OpBuilder &builder, Value input,
PadOpQuantizationAttr mlir::tosa::buildPadOpQuantizationAttr(OpBuilder &builder,
Value input) {
- auto inputType = input.getType().dyn_cast<ShapedType>();
+ auto inputType = dyn_cast<ShapedType>(input.getType());
if (!inputType)
return nullptr;
@@ -235,8 +235,8 @@ PadOpQuantizationAttr mlir::tosa::buildPadOpQuantizationAttr(OpBuilder &builder,
Type mlir::tosa::buildConvOpResultTypeInfo(OpBuilder &builder, Type outputType,
Value input, Value weight) {
- auto inputType = input.getType().dyn_cast<ShapedType>();
- auto weightType = weight.getType().dyn_cast<ShapedType>();
+ auto inputType = dyn_cast<ShapedType>(input.getType());
+ auto weightType = dyn_cast<ShapedType>(weight.getType());
assert(inputType && weightType &&
"Could not extract input or weight tensors from Conv op");
@@ -250,7 +250,7 @@ Type mlir::tosa::buildConvOpResultTypeInfo(OpBuilder &builder, Type outputType,
unsigned inputBits = inputQType.getStorageTypeIntegralWidth();
unsigned weightBits = weightQType.getStorageTypeIntegralWidth();
- auto outputShapedType = outputType.dyn_cast<ShapedType>();
+ auto outputShapedType = dyn_cast<ShapedType>(outputType);
assert(outputShapedType &&
"Could not extract output shape type from Conv op");
@@ -274,8 +274,8 @@ Type mlir::tosa::buildQTypeFromMinMax(OpBuilder builder, Type inputDType,
auto convfunc =
quant::ExpressedToQuantizedConverter::forInputType(inputDType);
- auto minElems = minAttr.dyn_cast<DenseFPElementsAttr>();
- auto maxElems = maxAttr.dyn_cast<DenseFPElementsAttr>();
+ auto minElems = dyn_cast<DenseFPElementsAttr>(minAttr);
+ auto maxElems = dyn_cast<DenseFPElementsAttr>(maxAttr);
SmallVector<double, 2> min, max;
@@ -291,12 +291,12 @@ Type mlir::tosa::buildQTypeFromMinMax(OpBuilder builder, Type inputDType,
for (auto i : maxElems)
max.push_back(FloatAttr::getValueAsDouble(i));
} else { // Just a single FP value.
- auto minVal = minAttr.dyn_cast<FloatAttr>();
+ auto minVal = dyn_cast<FloatAttr>(minAttr);
if (minVal)
min.push_back(minVal.getValueAsDouble());
else
return {};
- auto maxVal = maxAttr.dyn_cast<FloatAttr>();
+ auto maxVal = dyn_cast<FloatAttr>(maxAttr);
if (maxVal)
max.push_back(maxVal.getValueAsDouble());
else
@@ -309,7 +309,7 @@ Type mlir::tosa::buildQTypeFromMinMax(OpBuilder builder, Type inputDType,
builder.getUnknownLoc(), quantBits.getInt(), min[0], max[0],
narrowRange.getValue(), convfunc.expressedType, isSigned);
} else if (min.size() > 1) { // Per-axis quant on filterQuantDim.
- auto shape = inputDType.dyn_cast<ShapedType>();
+ auto shape = dyn_cast<ShapedType>(inputDType);
if (!shape)
return {};
if ((filterQuantDim) >= 0 && (shape.getRank() > filterQuantDim)) {
diff --git a/mlir/lib/Dialect/Traits.cpp b/mlir/lib/Dialect/Traits.cpp
index a1a30325258c..2ae67dcfd0be 100644
--- a/mlir/lib/Dialect/Traits.cpp
+++ b/mlir/lib/Dialect/Traits.cpp
@@ -116,7 +116,7 @@ bool OpTrait::util::getBroadcastedShape(ArrayRef<int64_t> shape1,
/// Returns the shape of the given type. Scalars will be considered as having a
/// shape with zero dimensions.
static ArrayRef<int64_t> getShape(Type type) {
- if (auto sType = type.dyn_cast<ShapedType>())
+ if (auto sType = dyn_cast<ShapedType>(type))
return sType.getShape();
return {};
}
@@ -142,8 +142,8 @@ Type OpTrait::util::getBroadcastedType(Type type1, Type type2,
// If one of the types is unranked tensor, then the other type shouldn't be
// vector and the result should have unranked tensor type.
- if (type1.isa<UnrankedTensorType>() || type2.isa<UnrankedTensorType>()) {
- if (type1.isa<VectorType>() || type2.isa<VectorType>())
+ if (isa<UnrankedTensorType>(type1) || isa<UnrankedTensorType>(type2)) {
+ if (isa<VectorType>(type1) || isa<VectorType>(type2))
return {};
return UnrankedTensorType::get(elementType);
}
@@ -151,7 +151,7 @@ Type OpTrait::util::getBroadcastedType(Type type1, Type type2,
// Returns the type kind if the given type is a vector or ranked tensor type.
// Returns std::nullopt otherwise.
auto getCompositeTypeKind = [](Type type) -> std::optional<TypeID> {
- if (type.isa<VectorType, RankedTensorType>())
+ if (isa<VectorType, RankedTensorType>(type))
return type.getTypeID();
return std::nullopt;
};
@@ -189,8 +189,8 @@ Type OpTrait::util::getBroadcastedType(Type type1, Type type2,
template <typename iterator_range>
static std::tuple<bool, bool> hasTensorOrVectorType(iterator_range types) {
return std::make_tuple(
- llvm::any_of(types, [](Type t) { return t.isa<TensorType>(); }),
- llvm::any_of(types, [](Type t) { return t.isa<VectorType>(); }));
+ llvm::any_of(types, [](Type t) { return isa<TensorType>(t); }),
+ llvm::any_of(types, [](Type t) { return isa<VectorType>(t); }));
}
static bool isCompatibleInferredReturnShape(ArrayRef<int64_t> inferred,
@@ -242,7 +242,7 @@ LogicalResult OpTrait::impl::verifyCompatibleOperandBroadcast(Operation *op) {
return op->emitError("cannot broadcast vector with tensor");
auto rankedOperands = make_filter_range(
- op->getOperandTypes(), [](Type t) { return t.isa<RankedTensorType>(); });
+ op->getOperandTypes(), [](Type t) { return isa<RankedTensorType>(t); });
// If all operands are unranked, then all result shapes are possible.
if (rankedOperands.empty())
@@ -261,7 +261,7 @@ LogicalResult OpTrait::impl::verifyCompatibleOperandBroadcast(Operation *op) {
}
auto rankedResults = make_filter_range(
- op->getResultTypes(), [](Type t) { return t.isa<RankedTensorType>(); });
+ op->getResultTypes(), [](Type t) { return isa<RankedTensorType>(t); });
// If all of the results are unranked then no further verification.
if (rankedResults.empty())
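The `getCompositeTypeKind` lambda in the Traits.cpp hunks above relies on the variadic form of `isa<>`. A sketch; `isVectorOrRankedTensor` is a hypothetical name:

```
#include "mlir/IR/BuiltinTypes.h"

// Variadic isa<A, B>(x) is true when x is any one of the listed classes,
// replacing the member form x.isa<A, B>().
bool isVectorOrRankedTensor(mlir::Type t) {
  return mlir::isa<mlir::VectorType, mlir::RankedTensorType>(t);
}
```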
diff --git a/mlir/lib/Dialect/Transform/Transforms/CheckUses.cpp b/mlir/lib/Dialect/Transform/Transforms/CheckUses.cpp
index 05fba01d689c..45fa644f42ec 100644
--- a/mlir/lib/Dialect/Transform/Transforms/CheckUses.cpp
+++ b/mlir/lib/Dialect/Transform/Transforms/CheckUses.cpp
@@ -148,14 +148,14 @@ public:
// TODO: when this is ported to the dataflow analysis infra, we should have
// proper support for region-based control flow.
Operation *valueSource =
- operand.get().isa<OpResult>()
+ isa<OpResult>(operand.get())
? operand.get().getDefiningOp()
: operand.get().getParentBlock()->getParentOp();
auto iface = cast<MemoryEffectOpInterface>(valueSource);
SmallVector<MemoryEffects::EffectInstance> instances;
iface.getEffectsOnResource(transform::TransformMappingResource::get(),
instances);
- assert((operand.get().isa<BlockArgument>() ||
+ assert((isa<BlockArgument>(operand.get()) ||
hasEffect<MemoryEffects::Allocate>(instances, operand.get())) &&
"expected the op defining the value to have an allocation effect "
"on it");
@@ -182,7 +182,7 @@ public:
// value is defined in the middle of the block, i.e., is not a block
// argument.
bool isOutermost = ancestor == ancestors.front();
- bool isFromBlockPartial = isOutermost && operand.get().isa<OpResult>();
+ bool isFromBlockPartial = isOutermost && isa<OpResult>(operand.get());
// Check if the value may be freed by operations between its definition
// (allocation) point in its block and the terminator of the block or the
diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
index 94fa2d3de22f..853889269d0f 100644
--- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
@@ -162,7 +162,7 @@ ArrayAttr mlir::getReassociationIndicesAttribute(
SmallVector<Attribute, 4> reassociationAttr =
llvm::to_vector<4>(llvm::map_range(
reassociation, [&](const ReassociationIndices &indices) -> Attribute {
- return b.getI64ArrayAttr(indices).cast<Attribute>();
+ return cast<Attribute>(b.getI64ArrayAttr(indices));
}));
return b.getArrayAttr(reassociationAttr);
}
@@ -267,7 +267,7 @@ LogicalResult mlir::reshapeLikeShapesAreCompatible(
}
bool mlir::hasNonIdentityLayout(Type type) {
- if (auto memrefType = type.dyn_cast<MemRefType>())
+ if (auto memrefType = dyn_cast<MemRefType>(type))
return !memrefType.getLayout().isIdentity();
return false;
}
diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
index 45edd5f89ffe..09137d3336cc 100644
--- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -19,7 +19,7 @@ bool isZeroIndex(OpFoldResult v) {
if (!v)
return false;
if (auto attr = v.dyn_cast<Attribute>()) {
- IntegerAttr intAttr = attr.dyn_cast<IntegerAttr>();
+ IntegerAttr intAttr = dyn_cast<IntegerAttr>(attr);
return intAttr && intAttr.getValue().isZero();
}
if (auto cst = v.get<Value>().getDefiningOp<arith::ConstantIndexOp>())
@@ -53,7 +53,7 @@ void dispatchIndexOpFoldResult(OpFoldResult ofr,
SmallVectorImpl<int64_t> &staticVec) {
auto v = ofr.dyn_cast<Value>();
if (!v) {
- APInt apInt = ofr.get<Attribute>().cast<IntegerAttr>().getValue();
+ APInt apInt = cast<IntegerAttr>(ofr.get<Attribute>()).getValue();
staticVec.push_back(apInt.getSExtValue());
return;
}
@@ -71,8 +71,8 @@ void dispatchIndexOpFoldResults(ArrayRef<OpFoldResult> ofrs,
/// Extract int64_t values from the assumed ArrayAttr of IntegerAttr.
SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) {
return llvm::to_vector<4>(
- llvm::map_range(attr.cast<ArrayAttr>(), [](Attribute a) -> int64_t {
- return a.cast<IntegerAttr>().getInt();
+ llvm::map_range(cast<ArrayAttr>(attr), [](Attribute a) -> int64_t {
+ return cast<IntegerAttr>(a).getInt();
}));
}
@@ -124,7 +124,7 @@ std::optional<int64_t> getConstantIntValue(OpFoldResult ofr) {
}
// Case 2: Check for IntegerAttr.
Attribute attr = ofr.dyn_cast<Attribute>();
- if (auto intAttr = attr.dyn_cast_or_null<IntegerAttr>())
+ if (auto intAttr = dyn_cast_or_null<IntegerAttr>(attr))
return intAttr.getValue().getSExtValue();
return std::nullopt;
}
@@ -184,7 +184,7 @@ decomposeMixedValues(Builder &b,
SmallVector<Value> dynamicValues;
for (const auto &it : mixedValues) {
if (it.is<Attribute>()) {
- staticValues.push_back(it.get<Attribute>().cast<IntegerAttr>().getInt());
+ staticValues.push_back(cast<IntegerAttr>(it.get<Attribute>()).getInt());
} else {
staticValues.push_back(ShapedType::kDynamic);
dynamicValues.push_back(it.get<Value>());
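The `getConstantIntValue` hunk above shows the null-tolerant variant. A sketch, assuming the attribute may be null (e.g. it came from `OpFoldResult::dyn_cast<Attribute>()`); `attrAsInt` is a hypothetical name:

```
#include <cstdint>
#include <optional>
#include "mlir/IR/BuiltinAttributes.h"

// dyn_cast_or_null<T>(x) replaces x.dyn_cast_or_null<T>(); unlike dyn_cast,
// it accepts a null input and simply returns null for it.
std::optional<int64_t> attrAsInt(mlir::Attribute attr) {
  if (auto intAttr = mlir::dyn_cast_or_null<mlir::IntegerAttr>(attr))
    return intAttr.getValue().getSExtValue();
  return std::nullopt;
}
```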
diff --git a/mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp b/mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp
index aed39f864400..a2977901f475 100644
--- a/mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp
@@ -21,9 +21,9 @@ bool mlir::isRowMajorMatmul(ArrayAttr indexingMaps) {
if (indexingMaps.size() != 3)
return false;
- auto map0 = indexingMaps[0].cast<AffineMapAttr>().getValue();
- auto map1 = indexingMaps[1].cast<AffineMapAttr>().getValue();
- auto map2 = indexingMaps[2].cast<AffineMapAttr>().getValue();
+ auto map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
+ auto map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
+ auto map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();
if (map0.getNumResults() != 2 || map1.getNumResults() != 2 ||
map2.getNumResults() != 2 || map0.getNumInputs() != 3 ||
@@ -47,9 +47,9 @@ bool mlir::isColumnMajorMatmul(ArrayAttr indexingMaps) {
if (indexingMaps.size() != 3)
return false;
- auto map0 = indexingMaps[0].cast<AffineMapAttr>().getValue();
- auto map1 = indexingMaps[1].cast<AffineMapAttr>().getValue();
- auto map2 = indexingMaps[2].cast<AffineMapAttr>().getValue();
+ auto map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
+ auto map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
+ auto map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();
if (map0.getNumResults() != 2 || map1.getNumResults() != 2 ||
map2.getNumResults() != 2 || map0.getNumInputs() != 3 ||
@@ -73,9 +73,9 @@ bool mlir::isRowMajorBatchMatmul(ArrayAttr indexingMaps) {
if (indexingMaps.size() != 3)
return false;
- auto map0 = indexingMaps[0].cast<AffineMapAttr>().getValue();
- auto map1 = indexingMaps[1].cast<AffineMapAttr>().getValue();
- auto map2 = indexingMaps[2].cast<AffineMapAttr>().getValue();
+ auto map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
+ auto map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue();
+ auto map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue();
if (map0.getNumResults() != 3 || map1.getNumResults() != 3 ||
map2.getNumResults() != 3 || map0.getNumInputs() != 4 ||
diff --git a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
index a6431043475a..ad7e367c71ab 100644
--- a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -30,14 +30,14 @@ struct TransferReadOpInterface
vector::TransferReadOp> {
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
- assert(opOperand.get().getType().isa<RankedTensorType>() &&
+ assert(isa<RankedTensorType>(opOperand.get().getType()) &&
"only tensor types expected");
return true;
}
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
- assert(opOperand.get().getType().isa<RankedTensorType>() &&
+ assert(isa<RankedTensorType>(opOperand.get().getType()) &&
"only tensor types expected");
return false;
}
@@ -50,7 +50,7 @@ struct TransferReadOpInterface
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
const BufferizationOptions &options) const {
auto readOp = cast<vector::TransferReadOp>(op);
- assert(readOp.getShapedType().isa<TensorType>() &&
+ assert(isa<TensorType>(readOp.getShapedType()) &&
"only tensor types expected");
FailureOr<Value> buffer = getBuffer(rewriter, readOp.getSource(), options);
if (failed(buffer))
@@ -74,7 +74,7 @@ struct TransferWriteOpInterface
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
const BufferizationOptions &options) const {
auto writeOp = cast<vector::TransferWriteOp>(op);
- assert(writeOp.getShapedType().isa<TensorType>() &&
+ assert(isa<TensorType>(writeOp.getShapedType()) &&
"only tensor types expected");
// Create a new transfer_write on buffer that doesn't have a return value.
@@ -99,14 +99,14 @@ struct GatherOpInterface
vector::GatherOp> {
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
- assert(opOperand.get().getType().isa<RankedTensorType>() &&
+ assert(isa<RankedTensorType>(opOperand.get().getType()) &&
"only tensor types expected");
return true;
}
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
- assert(opOperand.get().getType().isa<RankedTensorType>() &&
+ assert(isa<RankedTensorType>(opOperand.get().getType()) &&
"only tensor types expected");
return false;
}
@@ -119,7 +119,7 @@ struct GatherOpInterface
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
const BufferizationOptions &options) const {
auto gatherOp = cast<vector::GatherOp>(op);
- assert(gatherOp.getBaseType().isa<TensorType>() &&
+ assert(isa<TensorType>(gatherOp.getBaseType()) &&
"only tensor types expected");
FailureOr<Value> buffer = getBuffer(rewriter, gatherOp.getBase(), options);
if (failed(buffer))
@@ -266,7 +266,7 @@ struct YieldOpInterface
// may get dropped during the bufferization of vector.mask.
SmallVector<Value> newResults;
for (Value value : yieldOp.getOperands()) {
- if (value.getType().isa<TensorType>()) {
+ if (isa<TensorType>(value.getType())) {
FailureOr<Value> maybeBuffer = getBuffer(rewriter, value, options);
if (failed(maybeBuffer))
return failure();
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
index ad538fe4a682..7c606e0c35f0 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
@@ -49,7 +49,7 @@ public:
PatternRewriter &rewriter) const override {
auto loc = op.getLoc();
VectorType dstType = op.getResultVectorType();
- VectorType srcType = op.getSourceType().dyn_cast<VectorType>();
+ VectorType srcType = dyn_cast<VectorType>(op.getSourceType());
Type eltType = dstType.getElementType();
// Scalar to any vector can use splat.
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
index 16751f82dad2..986c5f81d60c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
@@ -96,9 +96,9 @@ static Value reshapeLoad(Location loc, Value val, VectorType type,
return rewriter.create<vector::ExtractOp>(loc, lowType, val, posAttr);
}
// Unroll leading dimensions.
- VectorType vType = lowType.cast<VectorType>();
+ VectorType vType = cast<VectorType>(lowType);
Type resType = VectorType::Builder(type).dropDim(index);
- auto resVectorType = resType.cast<VectorType>();
+ auto resVectorType = cast<VectorType>(resType);
Value result = rewriter.create<arith::ConstantOp>(
loc, resVectorType, rewriter.getZeroAttr(resVectorType));
for (int64_t d = 0, e = resVectorType.getDimSize(0); d < e; d++) {
@@ -126,7 +126,7 @@ static Value reshapeStore(Location loc, Value val, Value result,
}
// Unroll leading dimensions.
Type lowType = VectorType::Builder(type).dropDim(0);
- VectorType vType = lowType.cast<VectorType>();
+ VectorType vType = cast<VectorType>(lowType);
Type insType = VectorType::Builder(vType).dropDim(0);
for (int64_t d = 0, e = type.getDimSize(0); d < e; d++) {
auto posAttr = rewriter.getI64ArrayAttr(d);
@@ -160,7 +160,7 @@ createContractArithOp(Location loc, Value x, Value y, Value acc,
// Only valid for integer types.
return std::nullopt;
// Special case for fused multiply-add.
- if (acc && acc.getType().isa<VectorType>() && kind == CombiningKind::ADD) {
+ if (acc && isa<VectorType>(acc.getType()) && kind == CombiningKind::ADD) {
Value fma = rewriter.create<vector::FMAOp>(loc, x, y, acc);
if (mask)
// The fma op doesn't need explicit masking. However, fma ops used in
@@ -418,7 +418,7 @@ struct UnrolledOuterProductGenerator
Value promote(Value v, Type dstElementType) {
Type elementType = v.getType();
- auto vecType = elementType.dyn_cast<VectorType>();
+ auto vecType = dyn_cast<VectorType>(elementType);
if (vecType)
elementType = vecType.getElementType();
if (elementType == dstElementType)
@@ -426,7 +426,7 @@ struct UnrolledOuterProductGenerator
Type promotedType = dstElementType;
if (vecType)
promotedType = VectorType::get(vecType.getShape(), promotedType);
- if (dstElementType.isa<FloatType>())
+ if (isa<FloatType>(dstElementType))
return rewriter.create<arith::ExtFOp>(loc, promotedType, v);
return rewriter.create<arith::ExtSIOp>(loc, promotedType, v);
}
@@ -438,7 +438,7 @@ struct UnrolledOuterProductGenerator
if (mask && !maybeMask.has_value())
return failure();
- Type resElementType = res.getType().cast<VectorType>().getElementType();
+ Type resElementType = cast<VectorType>(res.getType()).getElementType();
for (int64_t k = 0; k < reductionSize; ++k) {
Value extractA = rewriter.create<vector::ExtractOp>(loc, lhs, k);
Value extractB = rewriter.create<vector::ExtractOp>(loc, rhs, k);
@@ -684,7 +684,7 @@ ContractionOpToDotLowering::matchAndRewrite(vector::ContractionOp op,
return failure();
}
- VectorType dstType = op.getResultType().cast<VectorType>();
+ VectorType dstType = cast<VectorType>(op.getResultType());
assert(dstType.getRank() >= 1 && dstType.getRank() <= 2 &&
"Expected dst type of rank 1 or 2");
@@ -695,7 +695,7 @@ ContractionOpToDotLowering::matchAndRewrite(vector::ContractionOp op,
// ExtractOp does not allow dynamic indexing, we must unroll explicitly.
Value res = rewriter.create<arith::ConstantOp>(loc, dstType,
rewriter.getZeroAttr(dstType));
- bool isInt = dstType.getElementType().isa<IntegerType>();
+ bool isInt = isa<IntegerType>(dstType.getElementType());
for (unsigned r = 0; r < dstRows; ++r) {
Value a = rewriter.create<vector::ExtractOp>(op.getLoc(), lhs, r);
for (unsigned c = 0; c < dstColumns; ++c) {
@@ -789,7 +789,7 @@ struct ContractOpToElementwise
} else {
// If the parallel dimension doesn't exist we will have to broadcast it.
lhsDims.push_back(
- contractOp.getResultType().cast<VectorType>().getDimSize(i));
+ cast<VectorType>(contractOp.getResultType()).getDimSize(i));
lhsTranspose.push_back(lhsDims.size() - 1);
}
std::optional<unsigned> rhsDim =
@@ -799,7 +799,7 @@ struct ContractOpToElementwise
} else {
// If the parallel dimension doesn't exist we will have to broadcast it.
rhsDims.push_back(
- contractOp.getResultType().cast<VectorType>().getDimSize(i));
+ cast<VectorType>(contractOp.getResultType()).getDimSize(i));
rhsTranspose.push_back(rhsDims.size() - 1);
}
}
@@ -969,7 +969,7 @@ FailureOr<Value> ContractionOpLowering::lowerParallel(PatternRewriter &rewriter,
Value mask) const {
VectorType lhsType = op.getLhsType();
VectorType rhsType = op.getRhsType();
- VectorType resType = op.getResultType().cast<VectorType>();
+ VectorType resType = cast<VectorType>(op.getResultType());
// Find the iterator type index and result index.
SmallVector<AffineMap> iMap = op.getIndexingMapsArray();
int64_t iterIndex = -1;
@@ -1044,10 +1044,10 @@ FailureOr<Value> ContractionOpLowering::lowerReduction(
VectorType lhsType = op.getLhsType();
VectorType rhsType = op.getRhsType();
Type resType = op.getResultType();
- if (resType.isa<VectorType>())
+ if (isa<VectorType>(resType))
return rewriter.notifyMatchFailure(op,
"did not expect a VectorType result");
- bool isInt = resType.isa<IntegerType>();
+ bool isInt = isa<IntegerType>(resType);
// Use iterator index 0.
int64_t iterIndex = 0;
SmallVector<AffineMap> iMap = op.getIndexingMapsArray();
@@ -1133,10 +1133,10 @@ public:
auto loc = op.getLoc();
VectorType lhsType = op.getOperandVectorTypeLHS();
- VectorType rhsType = op.getOperandTypeRHS().dyn_cast<VectorType>();
+ VectorType rhsType = dyn_cast<VectorType>(op.getOperandTypeRHS());
VectorType resType = op.getResultVectorType();
Type eltType = resType.getElementType();
- bool isInt = eltType.isa<IntegerType, IndexType>();
+ bool isInt = isa<IntegerType, IndexType>(eltType);
Value acc = (op.getAcc().empty()) ? nullptr : op.getAcc()[0];
vector::CombiningKind kind = op.getKind();
@@ -1231,7 +1231,7 @@ ContractionOpToMatmulOpLowering::matchAndRewrite(vector::ContractionOp op,
return failure();
Type dstElementType = op.getType();
- if (auto vecType = dstElementType.dyn_cast<VectorType>())
+ if (auto vecType = dyn_cast<VectorType>(dstElementType))
dstElementType = vecType.getElementType();
if (elementType != dstElementType)
return failure();
@@ -1259,8 +1259,8 @@ ContractionOpToMatmulOpLowering::matchAndRewrite(vector::ContractionOp op,
return failure();
// At this point lhs and rhs are in row-major.
- VectorType lhsType = lhs.getType().cast<VectorType>();
- VectorType rhsType = rhs.getType().cast<VectorType>();
+ VectorType lhsType = cast<VectorType>(lhs.getType());
+ VectorType rhsType = cast<VectorType>(rhs.getType());
int64_t lhsRows = lhsType.getDimSize(0);
int64_t lhsColumns = lhsType.getDimSize(1);
int64_t rhsColumns = rhsType.getDimSize(1);
@@ -1289,7 +1289,7 @@ ContractionOpToMatmulOpLowering::matchAndRewrite(vector::ContractionOp op,
llvm_unreachable("invalid contraction semantics");
Value res =
- elementType.isa<IntegerType>()
+ isa<IntegerType>(elementType)
? static_cast<Value>(rew.create<arith::AddIOp>(loc, op.getAcc(), mul))
: static_cast<Value>(
rew.create<arith::AddFOp>(loc, op.getAcc(), mul));
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
index 3f26558237a2..a0ed056fc7a3 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
@@ -52,7 +52,7 @@ public:
LogicalResult matchAndRewrite(vector::CreateMaskOp op,
PatternRewriter &rewriter) const override {
- auto dstType = op.getResult().getType().cast<VectorType>();
+ auto dstType = cast<VectorType>(op.getResult().getType());
int64_t rank = dstType.getRank();
if (rank <= 1)
return rewriter.notifyMatchFailure(
@@ -112,7 +112,7 @@ public:
if (rank == 0) {
assert(dimSizes.size() == 1 &&
"Expected exactly one dim size for a 0-D vector");
- bool value = dimSizes[0].cast<IntegerAttr>().getInt() == 1;
+ bool value = cast<IntegerAttr>(dimSizes[0]).getInt() == 1;
rewriter.replaceOpWithNewOp<arith::ConstantOp>(
op, dstType,
DenseIntElementsAttr::get(
@@ -122,14 +122,14 @@ public:
}
// Scalable constant masks can only be lowered for the "none set" case.
- if (dstType.cast<VectorType>().isScalable()) {
+ if (cast<VectorType>(dstType).isScalable()) {
rewriter.replaceOpWithNewOp<arith::ConstantOp>(
op, DenseElementsAttr::get(dstType, false));
return success();
}
int64_t trueDim = std::min(dstType.getDimSize(0),
- dimSizes[0].cast<IntegerAttr>().getInt());
+ cast<IntegerAttr>(dimSizes[0]).getInt());
if (rank == 1) {
// Express constant 1-D case in explicit vector form:
@@ -146,7 +146,7 @@ public:
VectorType::get(dstType.getShape().drop_front(), eltType);
SmallVector<int64_t> newDimSizes;
for (int64_t r = 1; r < rank; r++)
- newDimSizes.push_back(dimSizes[r].cast<IntegerAttr>().getInt());
+ newDimSizes.push_back(cast<IntegerAttr>(dimSizes[r]).getInt());
Value trueVal = rewriter.create<vector::ConstantMaskOp>(
loc, lowType, rewriter.getI64ArrayAttr(newDimSizes));
Value result = rewriter.create<arith::ConstantOp>(
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp
index eb2deba7bc46..463aab1ead38 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp
@@ -48,7 +48,7 @@ static Value genOperator(Location loc, Value x, Value y,
PatternRewriter &rewriter) {
using vector::CombiningKind;
- auto elType = x.getType().cast<VectorType>().getElementType();
+ auto elType = cast<VectorType>(x.getType()).getElementType();
bool isInt = elType.isIntOrIndex();
Value combinedResult{nullptr};
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
index f15d0c85fd19..4f68526ac401 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
@@ -29,7 +29,7 @@ inverseTransposeInBoundsAttr(OpBuilder &builder, ArrayAttr attr,
size_t index = 0;
for (unsigned pos : permutation)
newInBoundsValues[pos] =
- attr.getValue()[index++].cast<BoolAttr>().getValue();
+ cast<BoolAttr>(attr.getValue()[index++]).getValue();
return builder.getBoolArrayAttr(newInBoundsValues);
}
@@ -37,7 +37,7 @@ inverseTransposeInBoundsAttr(OpBuilder &builder, ArrayAttr attr,
/// dimensions.
static Value extendVectorRank(OpBuilder &builder, Location loc, Value vec,
int64_t addedRank) {
- auto originalVecType = vec.getType().cast<VectorType>();
+ auto originalVecType = cast<VectorType>(vec.getType());
SmallVector<int64_t> newShape(addedRank, 1);
newShape.append(originalVecType.getShape().begin(),
originalVecType.getShape().end());
@@ -257,7 +257,7 @@ struct TransferWriteNonPermutationLowering
// All the new dimensions added are inbound.
SmallVector<bool> newInBoundsValues(missingInnerDim.size(), true);
for (Attribute attr : op.getInBounds().value().getValue()) {
- newInBoundsValues.push_back(attr.cast<BoolAttr>().getValue());
+ newInBoundsValues.push_back(cast<BoolAttr>(attr).getValue());
}
newInBoundsAttr = rewriter.getBoolArrayAttr(newInBoundsValues);
}
@@ -315,7 +315,7 @@ struct TransferOpReduceRank : public OpRewritePattern<vector::TransferReadOp> {
// In the meantime, lower these to a scalar load when they pop up.
if (reducedShapeRank == 0) {
Value newRead;
- if (op.getShapedType().isa<TensorType>()) {
+ if (isa<TensorType>(op.getShapedType())) {
newRead = rewriter.create<tensor::ExtractOp>(
op.getLoc(), op.getSource(), op.getIndices());
} else {
@@ -397,7 +397,7 @@ struct TransferReadToVectorLoadLowering
&broadcastedDims))
return rewriter.notifyMatchFailure(read, "not minor identity + bcast");
- auto memRefType = read.getShapedType().dyn_cast<MemRefType>();
+ auto memRefType = dyn_cast<MemRefType>(read.getShapedType());
if (!memRefType)
return rewriter.notifyMatchFailure(read, "not a memref source");
@@ -418,11 +418,11 @@ struct TransferReadToVectorLoadLowering
// `vector.load` supports vector types as memref's elements only when the
// resulting vector type is the same as the element type.
auto memrefElTy = memRefType.getElementType();
- if (memrefElTy.isa<VectorType>() && memrefElTy != unbroadcastedVectorType)
+ if (isa<VectorType>(memrefElTy) && memrefElTy != unbroadcastedVectorType)
return rewriter.notifyMatchFailure(read, "incompatible element type");
// Otherwise, element types of the memref and the vector must match.
- if (!memrefElTy.isa<VectorType>() &&
+ if (!isa<VectorType>(memrefElTy) &&
memrefElTy != read.getVectorType().getElementType())
return rewriter.notifyMatchFailure(read, "non-matching element type");
@@ -543,7 +543,7 @@ struct TransferWriteToVectorStoreLowering
diag << "permutation map is not minor identity: " << write;
});
- auto memRefType = write.getShapedType().dyn_cast<MemRefType>();
+ auto memRefType = dyn_cast<MemRefType>(write.getShapedType());
if (!memRefType)
return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
diag << "not a memref type: " << write;
@@ -558,13 +558,13 @@ struct TransferWriteToVectorStoreLowering
// `vector.store` supports vector types as memref's elements only when the
// type of the vector value being written is the same as the element type.
auto memrefElTy = memRefType.getElementType();
- if (memrefElTy.isa<VectorType>() && memrefElTy != write.getVectorType())
+ if (isa<VectorType>(memrefElTy) && memrefElTy != write.getVectorType())
return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
diag << "elemental type mismatch: " << write;
});
// Otherwise, element types of the memref and the vector must match.
- if (!memrefElTy.isa<VectorType>() &&
+ if (!isa<VectorType>(memrefElTy) &&
memrefElTy != write.getVectorType().getElementType())
return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
diag << "elemental type mismatch: " << write;
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
index 42c1aa58c5e5..7d804ddcfa42 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
@@ -156,7 +156,7 @@ static Value createUnpackHiPs(ImplicitLocOpBuilder &b, Value v1, Value v2,
/// dst[511:384] := SELECT4(v2[511:0], mask[7:6])
static Value create4x128BitSuffle(ImplicitLocOpBuilder &b, Value v1, Value v2,
uint8_t mask) {
- assert(v1.getType().cast<VectorType>().getShape()[0] == 16 &&
+ assert(cast<VectorType>(v1.getType()).getShape()[0] == 16 &&
"expected a vector with length=16");
SmallVector<int64_t> shuffleMask;
auto appendToMask = [&](int64_t base, uint8_t control) {
@@ -291,7 +291,7 @@ static Value transposeToShuffle16x16(OpBuilder &builder, Value source, int m,
vs[0xf] = create4x128BitSuffle(b, t7, tf, 0xdd);
auto reshInputType = VectorType::get(
- {m, n}, source.getType().cast<VectorType>().getElementType());
+ {m, n}, cast<VectorType>(source.getType()).getElementType());
Value res =
b.create<arith::ConstantOp>(reshInputType, b.getZeroAttr(reshInputType));
for (int64_t i = 0; i < m; ++i)
@@ -329,7 +329,7 @@ public:
// Set up convenience transposition table.
SmallVector<int64_t> transp;
for (auto attr : op.getTransp())
- transp.push_back(attr.cast<IntegerAttr>().getInt());
+ transp.push_back(cast<IntegerAttr>(attr).getInt());
if (isShuffleLike(vectorTransformOptions.vectorTransposeLowering) &&
succeeded(isTranspose2DSlice(op)))
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 2b5706aaa774..e56aa62a1871 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -62,8 +62,8 @@ struct DistributedLoadStoreHelper {
Value laneId, Value zero)
: sequentialVal(sequentialVal), distributedVal(distributedVal),
laneId(laneId), zero(zero) {
- sequentialVectorType = sequentialVal.getType().dyn_cast<VectorType>();
- distributedVectorType = distributedVal.getType().dyn_cast<VectorType>();
+ sequentialVectorType = dyn_cast<VectorType>(sequentialVal.getType());
+ distributedVectorType = dyn_cast<VectorType>(distributedVal.getType());
if (sequentialVectorType && distributedVectorType)
distributionMap =
calculateImplicitMap(sequentialVectorType, distributedVectorType);
@@ -89,7 +89,7 @@ struct DistributedLoadStoreHelper {
"Must store either the preregistered distributed or the "
"preregistered sequential value.");
// Scalar case can directly use memref.store.
- if (!val.getType().isa<VectorType>())
+ if (!isa<VectorType>(val.getType()))
return b.create<memref::StoreOp>(loc, val, buffer, zero);
// Vector case must use vector::TransferWriteOp which will later lower to
@@ -131,7 +131,7 @@ struct DistributedLoadStoreHelper {
Value buildLoad(RewriterBase &b, Location loc, Type type, Value buffer) {
// Scalar case can directly use memref.load.
- if (!type.isa<VectorType>())
+ if (!isa<VectorType>(type))
return b.create<memref::LoadOp>(loc, buffer, zero);
// Other cases must be vector atm.
@@ -149,7 +149,7 @@ struct DistributedLoadStoreHelper {
}
SmallVector<bool> inBounds(indices.size(), true);
return b.create<vector::TransferReadOp>(
- loc, type.cast<VectorType>(), buffer, indices,
+ loc, cast<VectorType>(type), buffer, indices,
ArrayRef<bool>(inBounds.begin(), inBounds.end()));
}
@@ -630,14 +630,14 @@ struct WarpOpElementwise : public OpRewritePattern<WarpExecuteOnLane0Op> {
Location loc = warpOp.getLoc();
for (OpOperand &operand : elementWise->getOpOperands()) {
Type targetType;
- if (auto vecType = distributedVal.getType().dyn_cast<VectorType>()) {
+ if (auto vecType = dyn_cast<VectorType>(distributedVal.getType())) {
// If the result type is a vector, the operands must also be vectors.
- auto operandType = operand.get().getType().cast<VectorType>();
+ auto operandType = cast<VectorType>(operand.get().getType());
targetType =
VectorType::get(vecType.getShape(), operandType.getElementType());
} else {
auto operandType = operand.get().getType();
- assert(!operandType.isa<VectorType>() &&
+ assert(!isa<VectorType>(operandType) &&
"unexpected yield of vector from op with scalar result type");
targetType = operandType;
}
@@ -687,7 +687,7 @@ struct WarpOpConstant : public OpRewritePattern<WarpExecuteOnLane0Op> {
if (!yieldOperand)
return failure();
auto constantOp = yieldOperand->get().getDefiningOp<arith::ConstantOp>();
- auto dense = constantOp.getValue().dyn_cast<SplatElementsAttr>();
+ auto dense = dyn_cast<SplatElementsAttr>(constantOp.getValue());
if (!dense)
return failure();
unsigned operandIndex = yieldOperand->getOperandNumber();
@@ -737,8 +737,8 @@ struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
SmallVector<Value, 4> indices(read.getIndices().begin(),
read.getIndices().end());
- auto sequentialType = read.getResult().getType().cast<VectorType>();
- auto distributedType = distributedVal.getType().cast<VectorType>();
+ auto sequentialType = cast<VectorType>(read.getResult().getType());
+ auto distributedType = cast<VectorType>(distributedVal.getType());
AffineMap map = calculateImplicitMap(sequentialType, distributedType);
AffineMap indexMap = map.compose(read.getPermutationMap());
OpBuilder::InsertionGuard g(rewriter);
@@ -752,7 +752,7 @@ struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
unsigned indexPos = indexExpr.getPosition();
unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
int64_t scale =
- distributedVal.getType().cast<VectorType>().getDimSize(vectorPos);
+ cast<VectorType>(distributedVal.getType()).getDimSize(vectorPos);
indices[indexPos] = affine::makeComposedAffineApply(
rewriter, read.getLoc(), d0 + scale * d1,
{indices[indexPos], warpOp.getLaneid()});
@@ -845,7 +845,7 @@ struct WarpOpForwardOperand : public OpRewritePattern<WarpExecuteOnLane0Op> {
resultIndex = operand.getOperandNumber();
break;
}
- auto arg = operand.get().dyn_cast<BlockArgument>();
+ auto arg = dyn_cast<BlockArgument>(operand.get());
if (!arg || arg.getOwner()->getParentOp() != warpOp.getOperation())
continue;
Value warpOperand = warpOp.getArgs()[arg.getArgNumber()];
@@ -874,7 +874,7 @@ struct WarpOpBroadcast : public OpRewritePattern<WarpExecuteOnLane0Op> {
auto broadcastOp = operand->get().getDefiningOp<vector::BroadcastOp>();
Location loc = broadcastOp.getLoc();
auto destVecType =
- warpOp->getResultTypes()[operandNumber].cast<VectorType>();
+ cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
SmallVector<size_t> newRetIndices;
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
rewriter, warpOp, {broadcastOp.getSource()},
@@ -914,7 +914,7 @@ struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> {
// Rewrite vector.extract with 1d source to vector.extractelement.
if (extractSrcType.getRank() == 1) {
assert(extractOp.getPosition().size() == 1 && "expected 1 index");
- int64_t pos = extractOp.getPosition()[0].cast<IntegerAttr>().getInt();
+ int64_t pos = cast<IntegerAttr>(extractOp.getPosition()[0]).getInt();
rewriter.setInsertionPoint(extractOp);
rewriter.replaceOpWithNewOp<vector::ExtractElementOp>(
extractOp, extractOp.getVector(),
@@ -946,8 +946,8 @@ struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> {
// Find the distributed dimension. There should be exactly one.
auto distributedType =
- warpOp.getResult(operandNumber).getType().cast<VectorType>();
- auto yieldedType = operand->get().getType().cast<VectorType>();
+ cast<VectorType>(warpOp.getResult(operandNumber).getType());
+ auto yieldedType = cast<VectorType>(operand->get().getType());
int64_t distributedDim = -1;
for (int64_t i = 0; i < yieldedType.getRank(); ++i) {
if (distributedType.getDimSize(i) != yieldedType.getDimSize(i)) {
@@ -1083,7 +1083,7 @@ struct WarpOpInsertElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
auto insertOp = operand->get().getDefiningOp<vector::InsertElementOp>();
VectorType vecType = insertOp.getDestVectorType();
VectorType distrType =
- warpOp.getResult(operandNumber).getType().cast<VectorType>();
+ cast<VectorType>(warpOp.getResult(operandNumber).getType());
bool hasPos = static_cast<bool>(insertOp.getPosition());
// Yield destination vector, source scalar and position from warp op.
@@ -1171,7 +1171,7 @@ struct WarpOpInsert : public OpRewritePattern<WarpExecuteOnLane0Op> {
// Rewrite vector.insert with 1d dest to vector.insertelement.
if (insertOp.getDestVectorType().getRank() == 1) {
assert(insertOp.getPosition().size() == 1 && "expected 1 index");
- int64_t pos = insertOp.getPosition()[0].cast<IntegerAttr>().getInt();
+ int64_t pos = cast<IntegerAttr>(insertOp.getPosition()[0]).getInt();
rewriter.setInsertionPoint(insertOp);
rewriter.replaceOpWithNewOp<vector::InsertElementOp>(
insertOp, insertOp.getSource(), insertOp.getDest(),
@@ -1199,8 +1199,8 @@ struct WarpOpInsert : public OpRewritePattern<WarpExecuteOnLane0Op> {
// Find the distributed dimension. There should be exactly one.
auto distrDestType =
- warpOp.getResult(operandNumber).getType().cast<VectorType>();
- auto yieldedType = operand->get().getType().cast<VectorType>();
+ cast<VectorType>(warpOp.getResult(operandNumber).getType());
+ auto yieldedType = cast<VectorType>(operand->get().getType());
int64_t distrDestDim = -1;
for (int64_t i = 0; i < yieldedType.getRank(); ++i) {
if (distrDestType.getDimSize(i) != yieldedType.getDimSize(i)) {
@@ -1213,7 +1213,7 @@ struct WarpOpInsert : public OpRewritePattern<WarpExecuteOnLane0Op> {
assert(distrDestDim != -1 && "could not find distributed dimension");
// Compute the distributed source vector type.
- VectorType srcVecType = insertOp.getSourceType().cast<VectorType>();
+ VectorType srcVecType = cast<VectorType>(insertOp.getSourceType());
SmallVector<int64_t> distrSrcShape(srcVecType.getShape().begin(),
srcVecType.getShape().end());
// E.g.: vector.insert %s, %d [2] : vector<96xf32> into vector<128x96xf32>
@@ -1248,7 +1248,7 @@ struct WarpOpInsert : public OpRewritePattern<WarpExecuteOnLane0Op> {
int64_t elementsPerLane = distrDestType.getDimSize(distrDestDim);
SmallVector<int64_t> newPos = llvm::to_vector(
llvm::map_range(insertOp.getPosition(), [](Attribute attr) {
- return attr.cast<IntegerAttr>().getInt();
+ return cast<IntegerAttr>(attr).getInt();
}));
// tid of inserting lane: pos / elementsPerLane
Value insertingLane = rewriter.create<arith::ConstantIndexOp>(
@@ -1337,7 +1337,7 @@ struct WarpOpScfForOp : public OpRewritePattern<WarpExecuteOnLane0Op> {
if (!escapingValues.insert(operand->get()))
return;
Type distType = operand->get().getType();
- if (auto vecType = distType.cast<VectorType>()) {
+ if (auto vecType = cast<VectorType>(distType)) {
AffineMap map = distributionMapFn(operand->get());
distType = getDistributedType(vecType, map, warpOp.getWarpSize());
}
@@ -1359,7 +1359,7 @@ struct WarpOpScfForOp : public OpRewritePattern<WarpExecuteOnLane0Op> {
for (OpOperand &yieldOperand : yield->getOpOperands()) {
if (yieldOperand.get().getDefiningOp() != forOp.getOperation())
continue;
- auto forResult = yieldOperand.get().cast<OpResult>();
+ auto forResult = cast<OpResult>(yieldOperand.get());
newOperands.push_back(
newWarpOp.getResult(yieldOperand.getOperandNumber()));
yieldOperand.set(forOp.getIterOperands()[forResult.getResultNumber()]);
@@ -1463,7 +1463,7 @@ struct WarpOpReduction : public OpRewritePattern<WarpExecuteOnLane0Op> {
auto reductionOp =
cast<vector::ReductionOp>(yieldOperand->get().getDefiningOp());
- auto vectorType = reductionOp.getVector().getType().cast<VectorType>();
+ auto vectorType = cast<VectorType>(reductionOp.getVector().getType());
// Only rank 1 vectors supported.
if (vectorType.getRank() != 1)
return rewriter.notifyMatchFailure(
@@ -1564,7 +1564,7 @@ void mlir::vector::moveScalarUniformCode(WarpExecuteOnLane0Op warpOp) {
// operations from there.
for (auto &op : body->without_terminator()) {
bool hasVectorResult = llvm::any_of(op.getResults(), [](Value result) {
- return result.getType().isa<VectorType>();
+ return isa<VectorType>(result.getType());
});
if (!hasVectorResult && canBeHoisted(&op, isDefinedOutsideOfBody))
opsToMove.insert(&op);
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
index 6105e87573c2..8b2444199a50 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
@@ -136,10 +136,10 @@ struct CastAwayInsertLeadingOneDim : public OpRewritePattern<vector::InsertOp> {
Type oldSrcType = insertOp.getSourceType();
Type newSrcType = oldSrcType;
int64_t oldSrcRank = 0, newSrcRank = 0;
- if (auto type = oldSrcType.dyn_cast<VectorType>()) {
+ if (auto type = dyn_cast<VectorType>(oldSrcType)) {
newSrcType = trimLeadingOneDims(type);
oldSrcRank = type.getRank();
- newSrcRank = newSrcType.cast<VectorType>().getRank();
+ newSrcRank = cast<VectorType>(newSrcType).getRank();
}
VectorType oldDstType = insertOp.getDestVectorType();
@@ -199,7 +199,7 @@ struct CastAwayTransferReadLeadingOneDim
if (read.getMask())
return failure();
- auto shapedType = read.getSource().getType().cast<ShapedType>();
+ auto shapedType = cast<ShapedType>(read.getSource().getType());
if (shapedType.getElementType() != read.getVectorType().getElementType())
return failure();
@@ -247,7 +247,7 @@ struct CastAwayTransferWriteLeadingOneDim
if (write.getMask())
return failure();
- auto shapedType = write.getSource().getType().dyn_cast<ShapedType>();
+ auto shapedType = dyn_cast<ShapedType>(write.getSource().getType());
if (shapedType.getElementType() != write.getVectorType().getElementType())
return failure();
@@ -284,7 +284,7 @@ struct CastAwayTransferWriteLeadingOneDim
LogicalResult
mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
RewriterBase &rewriter) {
- VectorType oldAccType = contractOp.getAccType().dyn_cast<VectorType>();
+ VectorType oldAccType = dyn_cast<VectorType>(contractOp.getAccType());
if (oldAccType == nullptr)
return failure();
if (oldAccType.getRank() < 2)
@@ -418,7 +418,7 @@ public:
PatternRewriter &rewriter) const override {
if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1)
return failure();
- auto vecType = op->getResultTypes()[0].dyn_cast<VectorType>();
+ auto vecType = dyn_cast<VectorType>(op->getResultTypes()[0]);
if (!vecType)
return failure();
VectorType newVecType = trimLeadingOneDims(vecType);
@@ -427,7 +427,7 @@ public:
int64_t dropDim = vecType.getRank() - newVecType.getRank();
SmallVector<Value, 4> newOperands;
for (Value operand : op->getOperands()) {
- if (auto opVecType = operand.getType().dyn_cast<VectorType>()) {
+ if (auto opVecType = dyn_cast<VectorType>(operand.getType())) {
newOperands.push_back(rewriter.create<vector::ExtractOp>(
op->getLoc(), operand, splatZero(dropDim)));
} else {
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
index 313a3f9a9c09..37216cea7b61 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
@@ -21,7 +21,7 @@ using namespace mlir::vector;
// Helper that picks the proper sequence for inserting.
static Value insertOne(PatternRewriter &rewriter, Location loc, Value from,
Value into, int64_t offset) {
- auto vectorType = into.getType().cast<VectorType>();
+ auto vectorType = cast<VectorType>(into.getType());
if (vectorType.getRank() > 1)
return rewriter.create<InsertOp>(loc, from, into, offset);
return rewriter.create<vector::InsertElementOp>(
@@ -32,7 +32,7 @@ static Value insertOne(PatternRewriter &rewriter, Location loc, Value from,
// Helper that picks the proper sequence for extracting.
static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector,
int64_t offset) {
- auto vectorType = vector.getType().cast<VectorType>();
+ auto vectorType = cast<VectorType>(vector.getType());
if (vectorType.getRank() > 1)
return rewriter.create<ExtractOp>(loc, vector, offset);
return rewriter.create<vector::ExtractElementOp>(
@@ -134,10 +134,10 @@ public:
}
int64_t offset =
- op.getOffsets().getValue().front().cast<IntegerAttr>().getInt();
+ cast<IntegerAttr>(op.getOffsets().getValue().front()).getInt();
int64_t size = srcType.getShape().front();
int64_t stride =
- op.getStrides().getValue().front().cast<IntegerAttr>().getInt();
+ cast<IntegerAttr>(op.getStrides().getValue().front()).getInt();
auto loc = op.getLoc();
Value res = op.getDest();
@@ -174,7 +174,7 @@ public:
off += stride, ++idx) {
// 1. extract the proper subvector (or element) from source
Value extractedSource = extractOne(rewriter, loc, op.getSource(), idx);
- if (extractedSource.getType().isa<VectorType>()) {
+ if (isa<VectorType>(extractedSource.getType())) {
// 2. If we have a vector, extract the proper subvector from destination
// Otherwise we are at the element level and no need to recurse.
Value extractedDest = extractOne(rewriter, loc, op.getDest(), off);
@@ -208,11 +208,10 @@ public:
assert(!op.getOffsets().getValue().empty() && "Unexpected empty offsets");
int64_t offset =
- op.getOffsets().getValue().front().cast<IntegerAttr>().getInt();
- int64_t size =
- op.getSizes().getValue().front().cast<IntegerAttr>().getInt();
+ cast<IntegerAttr>(op.getOffsets().getValue().front()).getInt();
+ int64_t size = cast<IntegerAttr>(op.getSizes().getValue().front()).getInt();
int64_t stride =
- op.getStrides().getValue().front().cast<IntegerAttr>().getInt();
+ cast<IntegerAttr>(op.getStrides().getValue().front()).getInt();
assert(dstType.getElementType().isSignlessIntOrIndexOrFloat());
@@ -254,11 +253,10 @@ public:
return failure();
int64_t offset =
- op.getOffsets().getValue().front().cast<IntegerAttr>().getInt();
- int64_t size =
- op.getSizes().getValue().front().cast<IntegerAttr>().getInt();
+ cast<IntegerAttr>(op.getOffsets().getValue().front()).getInt();
+ int64_t size = cast<IntegerAttr>(op.getSizes().getValue().front()).getInt();
int64_t stride =
- op.getStrides().getValue().front().cast<IntegerAttr>().getInt();
+ cast<IntegerAttr>(op.getStrides().getValue().front()).getInt();
Location loc = op.getLoc();
SmallVector<Value> elements;
@@ -300,11 +298,10 @@ public:
assert(!op.getOffsets().getValue().empty() && "Unexpected empty offsets");
int64_t offset =
- op.getOffsets().getValue().front().cast<IntegerAttr>().getInt();
- int64_t size =
- op.getSizes().getValue().front().cast<IntegerAttr>().getInt();
+ cast<IntegerAttr>(op.getOffsets().getValue().front()).getInt();
+ int64_t size = cast<IntegerAttr>(op.getSizes().getValue().front()).getInt();
int64_t stride =
- op.getStrides().getValue().front().cast<IntegerAttr>().getInt();
+ cast<IntegerAttr>(op.getStrides().getValue().front()).getInt();
auto loc = op.getLoc();
auto elemType = dstType.getElementType();
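The repeated `offset`/`size`/`stride` hunks above show that the same free functions work on `Attribute` just as on `Type`. A hedged sketch; `firstOffset` and the bare `ArrayAttr` parameter are illustrative assumptions, not code from this patch:
```
// Sketch only: firstOffset is a hypothetical helper, not part of this patch.
#include "mlir/IR/BuiltinAttributes.h"

using namespace mlir;

static int64_t firstOffset(ArrayAttr offsets) {
  // Deprecated:  offsets.getValue().front().cast<IntegerAttr>().getInt()
  // Preferred; cast<> still asserts if the element is not an IntegerAttr:
  return cast<IntegerAttr>(offsets.getValue().front()).getInt();
}
```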
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 3a06d9bdea1f..68d8c92a94df 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -261,7 +261,7 @@ static MemRefType dropUnitDims(MemRefType inputType, ArrayRef<int64_t> offsets,
llvm::make_filter_range(sizes, [](int64_t sz) { return sz != 1; }));
Type rankReducedType = memref::SubViewOp::inferRankReducedResultType(
targetShape, inputType, offsets, sizes, strides);
- return canonicalizeStridedLayout(rankReducedType.cast<MemRefType>());
+ return canonicalizeStridedLayout(cast<MemRefType>(rankReducedType));
}
/// Creates a rank-reducing memref.subview op that drops unit dims from its
@@ -269,7 +269,7 @@ static MemRefType dropUnitDims(MemRefType inputType, ArrayRef<int64_t> offsets,
static Value rankReducingSubviewDroppingUnitDims(PatternRewriter &rewriter,
mlir::Location loc,
Value input) {
- MemRefType inputType = input.getType().cast<MemRefType>();
+ MemRefType inputType = cast<MemRefType>(input.getType());
assert(inputType.hasStaticShape());
SmallVector<int64_t> subViewOffsets(inputType.getRank(), 0);
SmallVector<int64_t> subViewStrides(inputType.getRank(), 1);
@@ -304,9 +304,9 @@ class TransferReadDropUnitDimsPattern
PatternRewriter &rewriter) const override {
auto loc = transferReadOp.getLoc();
Value vector = transferReadOp.getVector();
- VectorType vectorType = vector.getType().cast<VectorType>();
+ VectorType vectorType = cast<VectorType>(vector.getType());
Value source = transferReadOp.getSource();
- MemRefType sourceType = source.getType().dyn_cast<MemRefType>();
+ MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
// TODO: support tensor types.
if (!sourceType || !sourceType.hasStaticShape())
return failure();
@@ -347,9 +347,9 @@ class TransferWriteDropUnitDimsPattern
PatternRewriter &rewriter) const override {
auto loc = transferWriteOp.getLoc();
Value vector = transferWriteOp.getVector();
- VectorType vectorType = vector.getType().cast<VectorType>();
+ VectorType vectorType = cast<VectorType>(vector.getType());
Value source = transferWriteOp.getSource();
- MemRefType sourceType = source.getType().dyn_cast<MemRefType>();
+ MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
// TODO: support tensor type.
if (!sourceType || !sourceType.hasStaticShape())
return failure();
@@ -406,7 +406,7 @@ static int64_t hasMatchingInnerContigousShape(MemRefType memrefType,
/// input starting at `firstDimToCollapse`.
static Value collapseInnerDims(PatternRewriter &rewriter, mlir::Location loc,
Value input, int64_t firstDimToCollapse) {
- ShapedType inputType = input.getType().cast<ShapedType>();
+ ShapedType inputType = cast<ShapedType>(input.getType());
if (inputType.getRank() == 1)
return input;
SmallVector<ReassociationIndices> reassociation;
@@ -451,9 +451,9 @@ class FlattenContiguousRowMajorTransferReadPattern
PatternRewriter &rewriter) const override {
auto loc = transferReadOp.getLoc();
Value vector = transferReadOp.getVector();
- VectorType vectorType = vector.getType().cast<VectorType>();
+ VectorType vectorType = cast<VectorType>(vector.getType());
Value source = transferReadOp.getSource();
- MemRefType sourceType = source.getType().dyn_cast<MemRefType>();
+ MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
// Contiguity check is valid on tensors only.
if (!sourceType)
return failure();
@@ -481,7 +481,7 @@ class FlattenContiguousRowMajorTransferReadPattern
Value collapsedSource =
collapseInnerDims(rewriter, loc, source, firstContiguousInnerDim);
MemRefType collapsedSourceType =
- collapsedSource.getType().dyn_cast<MemRefType>();
+ dyn_cast<MemRefType>(collapsedSource.getType());
int64_t collapsedRank = collapsedSourceType.getRank();
assert(collapsedRank == firstContiguousInnerDim + 1);
SmallVector<AffineExpr, 1> dimExprs{
@@ -494,7 +494,7 @@ class FlattenContiguousRowMajorTransferReadPattern
loc, flatVectorType, collapsedSource, collapsedIndices, collapsedMap);
flatRead.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
- transferReadOp, vector.getType().cast<VectorType>(), flatRead);
+ transferReadOp, cast<VectorType>(vector.getType()), flatRead);
return success();
}
};
@@ -511,9 +511,9 @@ class FlattenContiguousRowMajorTransferWritePattern
PatternRewriter &rewriter) const override {
auto loc = transferWriteOp.getLoc();
Value vector = transferWriteOp.getVector();
- VectorType vectorType = vector.getType().cast<VectorType>();
+ VectorType vectorType = cast<VectorType>(vector.getType());
Value source = transferWriteOp.getSource();
- MemRefType sourceType = source.getType().dyn_cast<MemRefType>();
+ MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
// Contiguity check is valid on tensors only.
if (!sourceType)
return failure();
@@ -541,7 +541,7 @@ class FlattenContiguousRowMajorTransferWritePattern
Value collapsedSource =
collapseInnerDims(rewriter, loc, source, firstContiguousInnerDim);
MemRefType collapsedSourceType =
- collapsedSource.getType().cast<MemRefType>();
+ cast<MemRefType>(collapsedSource.getType());
int64_t collapsedRank = collapsedSourceType.getRank();
assert(collapsedRank == firstContiguousInnerDim + 1);
SmallVector<AffineExpr, 1> dimExprs{
@@ -610,7 +610,7 @@ class RewriteScalarExtractElementOfTransferRead
*getConstantIntValue(ofr));
}
}
- if (xferOp.getSource().getType().isa<MemRefType>()) {
+ if (isa<MemRefType>(xferOp.getSource().getType())) {
rewriter.replaceOpWithNewOp<memref::LoadOp>(extractOp, xferOp.getSource(),
newIndices);
} else {
@@ -637,7 +637,7 @@ class RewriteScalarExtractOfTransferRead
LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
PatternRewriter &rewriter) const override {
// Only match scalar extracts.
- if (extractOp.getType().isa<VectorType>())
+ if (isa<VectorType>(extractOp.getType()))
return failure();
auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
if (!xferOp)
@@ -660,7 +660,7 @@ class RewriteScalarExtractOfTransferRead
SmallVector<Value> newIndices(xferOp.getIndices().begin(),
xferOp.getIndices().end());
for (const auto &it : llvm::enumerate(extractOp.getPosition())) {
- int64_t offset = it.value().cast<IntegerAttr>().getInt();
+ int64_t offset = cast<IntegerAttr>(it.value()).getInt();
int64_t idx =
newIndices.size() - extractOp.getPosition().size() + it.index();
OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
@@ -673,7 +673,7 @@ class RewriteScalarExtractOfTransferRead
extractOp.getLoc(), *getConstantIntValue(ofr));
}
}
- if (xferOp.getSource().getType().isa<MemRefType>()) {
+ if (isa<MemRefType>(xferOp.getSource().getType())) {
rewriter.replaceOpWithNewOp<memref::LoadOp>(extractOp, xferOp.getSource(),
newIndices);
} else {
@@ -714,7 +714,7 @@ class RewriteScalarWrite : public OpRewritePattern<vector::TransferWriteOp> {
xferOp.getVector(), pos);
}
// Construct a scalar store.
- if (xferOp.getSource().getType().isa<MemRefType>()) {
+ if (isa<MemRefType>(xferOp.getSource().getType())) {
rewriter.replaceOpWithNewOp<memref::StoreOp>(
xferOp, scalar, xferOp.getSource(), xferOp.getIndices());
} else {
@@ -732,12 +732,12 @@ void mlir::vector::transferOpflowOpt(RewriterBase &rewriter,
// Run store to load forwarding first since it can expose more dead store
// opportunity.
rootOp->walk([&](vector::TransferReadOp read) {
- if (read.getShapedType().isa<MemRefType>())
+ if (isa<MemRefType>(read.getShapedType()))
opt.storeToLoadForwarding(read);
});
opt.removeDeadOp();
rootOp->walk([&](vector::TransferWriteOp write) {
- if (write.getShapedType().isa<MemRefType>())
+ if (isa<MemRefType>(write.getShapedType()))
opt.deadStoreOp(write);
});
opt.removeDeadOp();
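The `isa<MemRefType>(...)` rewrites above apply the same rule to pure type queries, where no result value is needed. A small sketch; `storesToMemory` is invented for illustration:
```
// Sketch only: storesToMemory is a hypothetical helper, not part of this patch.
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"

using namespace mlir;

static bool storesToMemory(Value value) {
  // Deprecated:  value.getType().isa<MemRefType>()
  // Preferred:
  return isa<MemRefType>(value.getType());
}
```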
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
index 34a7ce16ce98..6dacb1e199f3 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
@@ -190,7 +190,7 @@ createSubViewIntersection(RewriterBase &b, VectorTransferOpInterface xferOp,
Location loc = xferOp.getLoc();
int64_t memrefRank = xferOp.getShapedType().getRank();
// TODO: relax this precondition, will require rank-reducing subviews.
- assert(memrefRank == alloc.getType().cast<MemRefType>().getRank() &&
+ assert(memrefRank == cast<MemRefType>(alloc.getType()).getRank() &&
"Expected memref rank to match the alloc rank");
ValueRange leadingIndices =
xferOp.indices().take_front(xferOp.getLeadingShapedRank());
@@ -571,8 +571,8 @@ LogicalResult mlir::vector::splitFullAndPartialTransfer(
}
MemRefType compatibleMemRefType =
- getCastCompatibleMemRefType(xferOp.getShapedType().cast<MemRefType>(),
- alloc.getType().cast<MemRefType>());
+ getCastCompatibleMemRefType(cast<MemRefType>(xferOp.getShapedType()),
+ cast<MemRefType>(alloc.getType()));
if (!compatibleMemRefType)
return failure();
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 44f3a10c4da5..d634d6a19030 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -93,9 +93,9 @@ struct ShapeCastOpFolder : public OpRewritePattern<vector::ShapeCastOp> {
PatternRewriter &rewriter) const override {
// Check if 'shapeCastOp' has vector source/result type.
auto sourceVectorType =
- shapeCastOp.getSource().getType().dyn_cast_or_null<VectorType>();
+ dyn_cast_or_null<VectorType>(shapeCastOp.getSource().getType());
auto resultVectorType =
- shapeCastOp.getResult().getType().dyn_cast_or_null<VectorType>();
+ dyn_cast_or_null<VectorType>(shapeCastOp.getResult().getType());
if (!sourceVectorType || !resultVectorType)
return failure();
@@ -105,7 +105,7 @@ struct ShapeCastOpFolder : public OpRewritePattern<vector::ShapeCastOp> {
if (!sourceShapeCastOp)
return failure();
auto operandSourceVectorType =
- sourceShapeCastOp.getSource().getType().cast<VectorType>();
+ cast<VectorType>(sourceShapeCastOp.getSource().getType());
auto operandResultVectorType = sourceShapeCastOp.getType();
// Check if shape cast operations invert each other.
@@ -342,7 +342,7 @@ struct CombineContractBroadcast
if (!broadcast)
continue;
// contractionOp can only take vector as operands.
- auto srcType = broadcast.getSourceType().dyn_cast<VectorType>();
+ auto srcType = dyn_cast<VectorType>(broadcast.getSourceType());
if (!srcType ||
srcType.getRank() == broadcast.getResultVectorType().getRank())
continue;
@@ -455,7 +455,7 @@ struct ReorderCastOpsOnBroadcast
return failure();
Type castResTy = getElementTypeOrSelf(op->getResult(0));
- if (auto vecTy = bcastOp.getSourceType().dyn_cast<VectorType>())
+ if (auto vecTy = dyn_cast<VectorType>(bcastOp.getSourceType()))
castResTy = VectorType::get(vecTy.getShape(), castResTy);
auto *castOp =
rewriter.create(op->getLoc(), op->getName().getIdentifier(),
@@ -530,7 +530,7 @@ struct ReorderElementwiseOpsOnTranspose final
// This is a constant. Create a reverse transpose op for it.
auto vectorType = VectorType::get(
srcType.getShape(),
- operand.getType().cast<VectorType>().getElementType());
+ cast<VectorType>(operand.getType()).getElementType());
srcValues.push_back(rewriter.create<vector::TransposeOp>(
operand.getLoc(), vectorType, operand,
rewriter.getI64ArrayAttr(invOrder)));
@@ -539,7 +539,7 @@ struct ReorderElementwiseOpsOnTranspose final
auto vectorType = VectorType::get(
srcType.getShape(),
- op->getResultTypes()[0].cast<VectorType>().getElementType());
+ cast<VectorType>(op->getResultTypes()[0]).getElementType());
Operation *elementwiseOp =
rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues,
vectorType, op->getAttrs());
@@ -693,7 +693,7 @@ struct BubbleDownBitCastForStridedSliceExtract
}
SmallVector<int64_t> dims =
- llvm::to_vector<4>(extractOp.getType().cast<VectorType>().getShape());
+ llvm::to_vector<4>(cast<VectorType>(extractOp.getType()).getShape());
dims.back() = dims.back() / expandRatio;
VectorType newExtractType =
VectorType::get(dims, castSrcType.getElementType());
@@ -996,7 +996,7 @@ public:
LogicalResult matchAndRewrite(vector::CreateMaskOp op,
PatternRewriter &rewriter) const override {
auto dstType = op.getType();
- if (dstType.cast<VectorType>().isScalable())
+ if (cast<VectorType>(dstType).isScalable())
return failure();
int64_t rank = dstType.getRank();
if (rank > 1)
@@ -1026,7 +1026,7 @@ class DropInnerMostUnitDims : public OpRewritePattern<vector::TransferReadOp> {
if (readOp.getMask())
return failure();
- auto srcType = readOp.getSource().getType().dyn_cast<MemRefType>();
+ auto srcType = dyn_cast<MemRefType>(readOp.getSource().getType());
if (!srcType || !srcType.hasStaticShape())
return failure();
@@ -1060,13 +1060,13 @@ class DropInnerMostUnitDims : public OpRewritePattern<vector::TransferReadOp> {
MemRefType resultMemrefType;
MemRefLayoutAttrInterface layout = srcType.getLayout();
- if (layout.isa<AffineMapAttr>() && layout.isIdentity()) {
+ if (isa<AffineMapAttr>(layout) && layout.isIdentity()) {
resultMemrefType = MemRefType::get(
srcType.getShape().drop_back(dimsToDrop), srcType.getElementType(),
nullptr, srcType.getMemorySpace());
} else {
MemRefLayoutAttrInterface updatedLayout;
- if (auto strided = layout.dyn_cast<StridedLayoutAttr>()) {
+ if (auto strided = dyn_cast<StridedLayoutAttr>(layout)) {
auto strides =
llvm::to_vector(strided.getStrides().drop_back(dimsToDrop));
updatedLayout = StridedLayoutAttr::get(strided.getContext(),
@@ -1099,7 +1099,7 @@ class DropInnerMostUnitDims : public OpRewritePattern<vector::TransferReadOp> {
loc, resultMemrefType, readOp.getSource(), offsets, srcType.getShape(),
strides);
auto permMap = getTransferMinorIdentityMap(
- rankedReducedView.getType().cast<ShapedType>(), resultTargetVecType);
+ cast<ShapedType>(rankedReducedView.getType()), resultTargetVecType);
Value result = rewriter.create<vector::TransferReadOp>(
loc, resultTargetVecType, rankedReducedView,
readOp.getIndices().drop_back(dimsToDrop), AffineMapAttr::get(permMap),
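The `ShapeCastOpFolder` hunk at the top of this file also migrates the null-tolerant variant: `dyn_cast_or_null<T>(x)` accepts a null input where plain `dyn_cast` would assert. A sketch under the same assumptions; `elementTypeOrNull` is an invented name:
```
// Sketch only: elementTypeOrNull is a hypothetical helper, not part of this patch.
#include "mlir/IR/BuiltinTypes.h"

using namespace mlir;

static Type elementTypeOrNull(Type type) {
  // Deprecated:  type.dyn_cast_or_null<VectorType>()
  // Preferred; unlike dyn_cast<>, this tolerates a null Type:
  if (auto vecType = dyn_cast_or_null<VectorType>(type))
    return vecType.getElementType();
  return Type();
}
```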
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index f56e7cf25603..5eee318b51b3 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -316,7 +316,7 @@ struct UnrollContractionPattern
auto targetShape = getTargetShape(options, contractOp);
if (!targetShape)
return failure();
- auto dstVecType = contractOp.getResultType().cast<VectorType>();
+ auto dstVecType = cast<VectorType>(contractOp.getResultType());
SmallVector<int64_t> originalSize = *contractOp.getShapeForUnroll();
Location loc = contractOp.getLoc();
@@ -491,7 +491,7 @@ struct UnrollElementwisePattern : public RewritePattern {
auto targetShape = getTargetShape(options, op);
if (!targetShape)
return failure();
- auto dstVecType = op->getResult(0).getType().cast<VectorType>();
+ auto dstVecType = cast<VectorType>(op->getResult(0).getType());
SmallVector<int64_t> originalSize =
*cast<VectorUnrollOpInterface>(op).getShapeForUnroll();
SmallVector<int64_t> ratio = *computeShapeRatio(originalSize, *targetShape);
@@ -512,7 +512,7 @@ struct UnrollElementwisePattern : public RewritePattern {
getVectorOffset(ratioStrides, i, *targetShape);
SmallVector<Value> extractOperands;
for (OpOperand &operand : op->getOpOperands()) {
- auto vecType = operand.get().getType().template dyn_cast<VectorType>();
+ auto vecType = dyn_cast<VectorType>(operand.get().getType());
if (!vecType) {
extractOperands.push_back(operand.get());
continue;
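A side benefit visible in the hunk above: the method spelling needs the `template` disambiguator in dependent contexts (`.template dyn_cast<VectorType>()`), while the free function does not. A sketch; `OpTy` and `firstResultVectorType` are illustrative, not from the patch:
```
// Sketch only: firstResultVectorType is a hypothetical helper.
#include "mlir/IR/BuiltinTypes.h"

using namespace mlir;

template <typename OpTy>
static VectorType firstResultVectorType(OpTy op) {
  // Deprecated, needs `template` because the expression is type-dependent:
  //   op.getResult(0).getType().template dyn_cast<VectorType>()
  // Preferred, no disambiguator needed:
  return dyn_cast<VectorType>(op.getResult(0).getType());
}
```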
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index e77a13a9c653..a1451fbf7f31 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -36,9 +36,9 @@ using namespace mlir;
/// the type of `source`.
Value mlir::vector::createOrFoldDimOp(OpBuilder &b, Location loc, Value source,
int64_t dim) {
- if (source.getType().isa<UnrankedMemRefType, MemRefType>())
+ if (isa<UnrankedMemRefType, MemRefType>(source.getType()))
return b.createOrFold<memref::DimOp>(loc, source, dim);
- if (source.getType().isa<UnrankedTensorType, RankedTensorType>())
+ if (isa<UnrankedTensorType, RankedTensorType>(source.getType()))
return b.createOrFold<tensor::DimOp>(loc, source, dim);
llvm_unreachable("Expected MemRefType or TensorType");
}
@@ -89,7 +89,7 @@ mlir::vector::isTranspose2DSlice(vector::TransposeOp op) {
SmallVector<int64_t> transp;
for (auto attr : op.getTransp())
- transp.push_back(attr.cast<IntegerAttr>().getInt());
+ transp.push_back(cast<IntegerAttr>(attr).getInt());
// Check whether the two source vector dimensions that are greater than one
// must be transposed with each other so that we can apply one of the 2-D
@@ -223,7 +223,7 @@ bool matcher::operatesOnSuperVectorsOf(Operation &op,
}
return false;
} else if (op.getNumResults() == 1) {
- if (auto v = op.getResult(0).getType().dyn_cast<VectorType>()) {
+ if (auto v = dyn_cast<VectorType>(op.getResult(0).getType())) {
superVectorType = v;
} else {
// Not a vector type.
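`createOrFoldDimOp` above uses the multi-type query, which likewise has a free-function form: `isa<UnrankedMemRefType, MemRefType>(t)` succeeds if any listed type matches. A sketch, with `isBufferLike` invented for illustration:
```
// Sketch only: isBufferLike is a hypothetical helper, not part of this patch.
#include "mlir/IR/BuiltinTypes.h"

using namespace mlir;

static bool isBufferLike(Type type) {
  // Deprecated:  type.isa<UnrankedMemRefType, MemRefType>()
  // Preferred; the variadic form is true if either case matches:
  return isa<UnrankedMemRefType, MemRefType>(type);
}
```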
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
index e806db7b7cde..b36f2978d20e 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
@@ -22,11 +22,11 @@ using namespace mlir::x86vector;
/// Extracts the "main" vector element type from the given X86Vector operation.
template <typename OpTy>
static Type getSrcVectorElementType(OpTy op) {
- return op.getSrc().getType().template cast<VectorType>().getElementType();
+ return cast<VectorType>(op.getSrc().getType()).getElementType();
}
template <>
Type getSrcVectorElementType(Vp2IntersectOp op) {
- return op.getA().getType().template cast<VectorType>().getElementType();
+ return cast<VectorType>(op.getA().getType()).getElementType();
}
namespace {