summaryrefslogtreecommitdiff
path: root/mlir/lib/Dialect/SPIRV
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect/SPIRV')
-rw-r--r--mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp4
-rw-r--r--mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp14
-rw-r--r--mlir/lib/Dialect/SPIRV/Transforms/RewriteInsertsPass.cpp12
-rw-r--r--mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp59
-rw-r--r--mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp4
-rw-r--r--mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp36
-rw-r--r--mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp4
-rw-r--r--mlir/lib/Dialect/SPIRV/Utils/LayoutUtils.cpp18
8 files changed, 74 insertions, 77 deletions
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp
index f154840b6f65..c22cb6710a7e 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp
@@ -42,8 +42,8 @@ public:
PatternRewriter &rewriter) const override {
SmallVector<NamedAttribute, 4> globalVarAttrs;
- auto ptrType = op.getType().cast<spirv::PointerType>();
- auto pointeeType = ptrType.getPointeeType().cast<spirv::StructType>();
+ auto ptrType = cast<spirv::PointerType>(op.getType());
+ auto pointeeType = cast<spirv::StructType>(ptrType.getPointeeType());
spirv::StructType structType = VulkanLayoutUtils::decorateType(pointeeType);
if (!structType)
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
index c0ab2152675e..9f2755da0922 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
@@ -51,19 +51,19 @@ createGlobalVarForEntryPointArgument(OpBuilder &builder, spirv::FuncOp funcOp,
// info create a variable of type !spirv.ptr<!spirv.struct<elementType>>. If
// not it must already be a !spirv.ptr<!spirv.struct<...>>.
auto varType = funcOp.getFunctionType().getInput(argIndex);
- if (varType.cast<spirv::SPIRVType>().isScalarOrVector()) {
+ if (cast<spirv::SPIRVType>(varType).isScalarOrVector()) {
auto storageClass = abiInfo.getStorageClass();
if (!storageClass)
return nullptr;
varType =
spirv::PointerType::get(spirv::StructType::get(varType), *storageClass);
}
- auto varPtrType = varType.cast<spirv::PointerType>();
- auto varPointeeType = varPtrType.getPointeeType().cast<spirv::StructType>();
+ auto varPtrType = cast<spirv::PointerType>(varType);
+ auto varPointeeType = cast<spirv::StructType>(varPtrType.getPointeeType());
// Set the offset information.
varPointeeType =
- VulkanLayoutUtils::decorateType(varPointeeType).cast<spirv::StructType>();
+ cast<spirv::StructType>(VulkanLayoutUtils::decorateType(varPointeeType));
if (!varPointeeType)
return nullptr;
@@ -98,7 +98,7 @@ getInterfaceVariables(spirv::FuncOp funcOp,
// Starting with version 1.4, the interface’s storage classes are all
// storage classes used in declaring all global variables referenced by the
// entry point’s call tree." We should consider the target environment here.
- switch (var.getType().cast<spirv::PointerType>().getStorageClass()) {
+ switch (cast<spirv::PointerType>(var.getType()).getStorageClass()) {
case spirv::StorageClass::Input:
case spirv::StorageClass::Output:
interfaceVarSet.insert(var.getOperation());
@@ -247,7 +247,7 @@ LogicalResult ProcessInterfaceVarABI::matchAndRewrite(
// at the start of the function. It is probably better to do the load just
// before the use. There might be multiple loads and currently there is no
// easy way to replace all uses with a sequence of operations.
- if (argType.value().cast<spirv::SPIRVType>().isScalarOrVector()) {
+ if (cast<spirv::SPIRVType>(argType.value()).isScalarOrVector()) {
auto zero =
spirv::ConstantOp::getZero(indexType, funcOp.getLoc(), rewriter);
auto loadPtr = rewriter.create<spirv::AccessChainOp>(
@@ -287,7 +287,7 @@ void LowerABIAttributesPass::runOnOperation() {
typeConverter.addSourceMaterialization([](OpBuilder &builder,
spirv::PointerType type,
ValueRange inputs, Location loc) {
- if (inputs.size() != 1 || !inputs[0].getType().isa<spirv::PointerType>())
+ if (inputs.size() != 1 || !isa<spirv::PointerType>(inputs[0].getType()))
return Value();
return builder.create<spirv::BitcastOp>(loc, type, inputs[0]).getResult();
});
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/RewriteInsertsPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/RewriteInsertsPass.cpp
index 51c36bd12db1..f38282f57a2c 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/RewriteInsertsPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/RewriteInsertsPass.cpp
@@ -84,15 +84,13 @@ void RewriteInsertsPass::runOnOperation() {
LogicalResult RewriteInsertsPass::collectInsertionChain(
spirv::CompositeInsertOp op,
SmallVectorImpl<spirv::CompositeInsertOp> &insertions) {
- auto indicesArrayAttr = op.getIndices().cast<ArrayAttr>();
+ auto indicesArrayAttr = cast<ArrayAttr>(op.getIndices());
// TODO: handle nested composite object.
if (indicesArrayAttr.size() == 1) {
- auto numElements = op.getComposite()
- .getType()
- .cast<spirv::CompositeType>()
+ auto numElements = cast<spirv::CompositeType>(op.getComposite().getType())
.getNumElements();
- auto index = indicesArrayAttr[0].cast<IntegerAttr>().getInt();
+ auto index = cast<IntegerAttr>(indicesArrayAttr[0]).getInt();
// Need a last index to collect a sequential chain.
if (index + 1 != numElements)
return failure();
@@ -109,9 +107,9 @@ LogicalResult RewriteInsertsPass::collectInsertionChain(
return failure();
--index;
- indicesArrayAttr = op.getIndices().cast<ArrayAttr>();
+ indicesArrayAttr = cast<ArrayAttr>(op.getIndices());
if ((indicesArrayAttr.size() != 1) ||
- (indicesArrayAttr[0].cast<IntegerAttr>().getInt() != index))
+ (cast<IntegerAttr>(indicesArrayAttr[0]).getInt() != index))
return failure();
}
}
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 5a5cdfe34194..793b02520f23 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -139,7 +139,7 @@ bool SPIRVTypeConverter::allows(spirv::Capability capability) {
// SPIR-V dialect. Keeping it local till the use case arises.
static std::optional<int64_t>
getTypeNumBytes(const SPIRVConversionOptions &options, Type type) {
- if (type.isa<spirv::ScalarType>()) {
+ if (isa<spirv::ScalarType>(type)) {
auto bitWidth = type.getIntOrFloatBitWidth();
// According to the SPIR-V spec:
// "There is no physical size or bit pattern defined for values with boolean
@@ -152,21 +152,21 @@ getTypeNumBytes(const SPIRVConversionOptions &options, Type type) {
return bitWidth / 8;
}
- if (auto complexType = type.dyn_cast<ComplexType>()) {
+ if (auto complexType = dyn_cast<ComplexType>(type)) {
auto elementSize = getTypeNumBytes(options, complexType.getElementType());
if (!elementSize)
return std::nullopt;
return 2 * *elementSize;
}
- if (auto vecType = type.dyn_cast<VectorType>()) {
+ if (auto vecType = dyn_cast<VectorType>(type)) {
auto elementSize = getTypeNumBytes(options, vecType.getElementType());
if (!elementSize)
return std::nullopt;
return vecType.getNumElements() * *elementSize;
}
- if (auto memRefType = type.dyn_cast<MemRefType>()) {
+ if (auto memRefType = dyn_cast<MemRefType>(type)) {
// TODO: Layout should also be controlled by the ABI attributes. For now
// using the layout from MemRef.
int64_t offset;
@@ -198,7 +198,7 @@ getTypeNumBytes(const SPIRVConversionOptions &options, Type type) {
return (offset + memrefSize) * *elementSize;
}
- if (auto tensorType = type.dyn_cast<TensorType>()) {
+ if (auto tensorType = dyn_cast<TensorType>(type)) {
if (!tensorType.hasStaticShape())
return std::nullopt;
@@ -246,12 +246,12 @@ convertScalarType(const spirv::TargetEnv &targetEnv,
return nullptr;
}
- if (auto floatType = type.dyn_cast<FloatType>()) {
+ if (auto floatType = dyn_cast<FloatType>(type)) {
LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
return Builder(targetEnv.getContext()).getF32Type();
}
- auto intType = type.cast<IntegerType>();
+ auto intType = cast<IntegerType>(type);
LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
return IntegerType::get(targetEnv.getContext(), /*width=*/32,
intType.getSignedness());
@@ -319,8 +319,8 @@ convertVectorType(const spirv::TargetEnv &targetEnv,
// Get extension and capability requirements for the given type.
SmallVector<ArrayRef<spirv::Extension>, 1> extensions;
SmallVector<ArrayRef<spirv::Capability>, 2> capabilities;
- type.cast<spirv::CompositeType>().getExtensions(extensions, storageClass);
- type.cast<spirv::CompositeType>().getCapabilities(capabilities, storageClass);
+ cast<spirv::CompositeType>(type).getExtensions(extensions, storageClass);
+ cast<spirv::CompositeType>(type).getCapabilities(capabilities, storageClass);
// If all requirements are met, then we can accept this type as-is.
if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) &&
@@ -415,8 +415,8 @@ static Type convertBoolMemrefType(const spirv::TargetEnv &targetEnv,
<< "using non-8-bit storage for bool types unimplemented");
return nullptr;
}
- auto elementType = IntegerType::get(type.getContext(), numBoolBits)
- .dyn_cast<spirv::ScalarType>();
+ auto elementType = dyn_cast<spirv::ScalarType>(
+ IntegerType::get(type.getContext(), numBoolBits));
if (!elementType)
return nullptr;
Type arrayElemType =
@@ -487,7 +487,7 @@ static Type convertSubByteMemrefType(const spirv::TargetEnv &targetEnv,
static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
const SPIRVConversionOptions &options,
MemRefType type) {
- auto attr = type.getMemorySpace().dyn_cast_or_null<spirv::StorageClassAttr>();
+ auto attr = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
if (!attr) {
LLVM_DEBUG(
llvm::dbgs()
@@ -499,7 +499,7 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
}
spirv::StorageClass storageClass = attr.getValue();
- if (type.getElementType().isa<IntegerType>()) {
+ if (isa<IntegerType>(type.getElementType())) {
if (type.getElementTypeBitWidth() == 1)
return convertBoolMemrefType(targetEnv, options, type, storageClass);
if (type.getElementTypeBitWidth() < 8)
@@ -508,17 +508,17 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
Type arrayElemType;
Type elementType = type.getElementType();
- if (auto vecType = elementType.dyn_cast<VectorType>()) {
+ if (auto vecType = dyn_cast<VectorType>(elementType)) {
arrayElemType =
convertVectorType(targetEnv, options, vecType, storageClass);
- } else if (auto complexType = elementType.dyn_cast<ComplexType>()) {
+ } else if (auto complexType = dyn_cast<ComplexType>(elementType)) {
arrayElemType =
convertComplexType(targetEnv, options, complexType, storageClass);
- } else if (auto scalarType = elementType.dyn_cast<spirv::ScalarType>()) {
+ } else if (auto scalarType = dyn_cast<spirv::ScalarType>(elementType)) {
arrayElemType =
convertScalarType(targetEnv, options, scalarType, storageClass);
- } else if (auto indexType = elementType.dyn_cast<IndexType>()) {
- type = convertIndexElementType(type, options).cast<MemRefType>();
+ } else if (auto indexType = dyn_cast<IndexType>(elementType)) {
+ type = cast<MemRefType>(convertIndexElementType(type, options));
arrayElemType = type.getElementType();
} else {
LLVM_DEBUG(
@@ -583,7 +583,7 @@ SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr,
addConversion([this](IndexType /*indexType*/) { return getIndexType(); });
addConversion([this](IntegerType intType) -> std::optional<Type> {
- if (auto scalarType = intType.dyn_cast<spirv::ScalarType>())
+ if (auto scalarType = dyn_cast<spirv::ScalarType>(intType))
return convertScalarType(this->targetEnv, this->options, scalarType);
if (intType.getWidth() < 8)
return convertSubByteIntegerType(this->options, intType);
@@ -591,7 +591,7 @@ SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr,
});
addConversion([this](FloatType floatType) -> std::optional<Type> {
- if (auto scalarType = floatType.dyn_cast<spirv::ScalarType>())
+ if (auto scalarType = dyn_cast<spirv::ScalarType>(floatType))
return convertScalarType(this->targetEnv, this->options, scalarType);
return Type();
});
@@ -784,7 +784,7 @@ static spirv::PointerType getPushConstantStorageType(unsigned elementCount,
static spirv::GlobalVariableOp getPushConstantVariable(Block &body,
unsigned elementCount) {
for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) {
- auto ptrType = varOp.getType().dyn_cast<spirv::PointerType>();
+ auto ptrType = dyn_cast<spirv::PointerType>(varOp.getType());
if (!ptrType)
continue;
@@ -792,10 +792,9 @@ static spirv::GlobalVariableOp getPushConstantVariable(Block &body,
// block statically used per shader entry point." So we should always reuse
// the existing one.
if (ptrType.getStorageClass() == spirv::StorageClass::PushConstant) {
- auto numElements = ptrType.getPointeeType()
- .cast<spirv::StructType>()
- .getElementType(0)
- .cast<spirv::ArrayType>()
+ auto numElements = cast<spirv::ArrayType>(
+ cast<spirv::StructType>(ptrType.getPointeeType())
+ .getElementType(0))
.getNumElements();
if (numElements == elementCount)
return varOp;
@@ -926,8 +925,8 @@ Value mlir::spirv::getOpenCLElementPtr(SPIRVTypeConverter &typeConverter,
linearizeIndex(indices, strides, offset, indexType, loc, builder);
}
Type pointeeType =
- basePtr.getType().cast<spirv::PointerType>().getPointeeType();
- if (pointeeType.isa<spirv::ArrayType>()) {
+ cast<spirv::PointerType>(basePtr.getType()).getPointeeType();
+ if (isa<spirv::ArrayType>(pointeeType)) {
linearizedIndices.push_back(linearIndex);
return builder.create<spirv::AccessChainOp>(loc, basePtr,
linearizedIndices);
@@ -1015,7 +1014,7 @@ bool SPIRVConversionTarget::isLegalOp(Operation *op) {
// Ensure that all types have been converted to SPIRV types.
if (llvm::any_of(valueTypes,
- [](Type t) { return !t.isa<spirv::SPIRVType>(); }))
+ [](Type t) { return !isa<spirv::SPIRVType>(t); }))
return false;
// Special treatment for global variables, whose type requirements are
@@ -1029,13 +1028,13 @@ bool SPIRVConversionTarget::isLegalOp(Operation *op) {
SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities;
for (Type valueType : valueTypes) {
typeExtensions.clear();
- valueType.cast<spirv::SPIRVType>().getExtensions(typeExtensions);
+ cast<spirv::SPIRVType>(valueType).getExtensions(typeExtensions);
if (failed(checkExtensionRequirements(op->getName(), this->targetEnv,
typeExtensions)))
return false;
typeCapabilities.clear();
- valueType.cast<spirv::SPIRVType>().getCapabilities(typeCapabilities);
+ cast<spirv::SPIRVType>(valueType).getCapabilities(typeCapabilities);
if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv,
typeCapabilities)))
return false;
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp
index 3cd4937e96f2..44fea8678559 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp
@@ -41,7 +41,7 @@ namespace {
//===----------------------------------------------------------------------===//
Attribute getScalarOrSplatAttr(Type type, int64_t value) {
APInt sizedValue(getElementTypeOrSelf(type).getIntOrFloatBitWidth(), value);
- if (auto intTy = type.dyn_cast<IntegerType>())
+ if (auto intTy = dyn_cast<IntegerType>(type))
return IntegerAttr::get(intTy, sizedValue);
return SplatElementsAttr::get(cast<ShapedType>(type), sizedValue);
@@ -149,7 +149,7 @@ struct ExpandMulExtendedPattern final : OpRewritePattern<MulExtendedOp> {
// Currently, WGSL only supports 32-bit integer types. Any other integer
// types should already have been promoted/demoted to i32.
- auto elemTy = getElementTypeOrSelf(lhs.getType()).cast<IntegerType>();
+ auto elemTy = cast<IntegerType>(getElementTypeOrSelf(lhs.getType()));
if (elemTy.getIntOrFloatBitWidth() != 32)
return rewriter.notifyMatchFailure(
loc,
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp
index 97f16d1b1b95..ea856c748677 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp
@@ -65,16 +65,16 @@ static AliasedResourceMap collectAliasedResources(spirv::ModuleOp moduleOp) {
/// `!spirv.ptr<!spirv.struct<!spirv.rtarray<...>>>`. Returns null type
/// otherwise.
static Type getRuntimeArrayElementType(Type type) {
- auto ptrType = type.dyn_cast<spirv::PointerType>();
+ auto ptrType = dyn_cast<spirv::PointerType>(type);
if (!ptrType)
return {};
- auto structType = ptrType.getPointeeType().dyn_cast<spirv::StructType>();
+ auto structType = dyn_cast<spirv::StructType>(ptrType.getPointeeType());
if (!structType || structType.getNumElements() != 1)
return {};
auto rtArrayType =
- structType.getElementType(0).dyn_cast<spirv::RuntimeArrayType>();
+ dyn_cast<spirv::RuntimeArrayType>(structType.getElementType(0));
if (!rtArrayType)
return {};
@@ -97,7 +97,7 @@ deduceCanonicalResource(ArrayRef<spirv::SPIRVType> types) {
for (const auto &indexedTypes : llvm::enumerate(types)) {
spirv::SPIRVType type = indexedTypes.value();
assert(type.isScalarOrVector());
- if (auto vectorType = type.dyn_cast<VectorType>()) {
+ if (auto vectorType = dyn_cast<VectorType>(type)) {
if (vectorType.getNumElements() % 2 != 0)
return std::nullopt; // Odd-sized vector has special layout
// requirements.
@@ -277,7 +277,7 @@ void ResourceAliasAnalysis::recordIfUnifiable(
if (!elementType)
return; // Unexpected resource variable type.
- auto type = elementType.cast<spirv::SPIRVType>();
+ auto type = cast<spirv::SPIRVType>(elementType);
if (!type.isScalarOrVector())
return; // Unexpected resource element type.
@@ -370,7 +370,7 @@ struct ConvertAccessChain : public ConvertAliasResource<spirv::AccessChainOp> {
Location loc = acOp.getLoc();
- if (srcElemType.isIntOrFloat() && dstElemType.isa<VectorType>()) {
+ if (srcElemType.isIntOrFloat() && isa<VectorType>(dstElemType)) {
// The source indices are for a buffer with scalar element types. Rewrite
// them into a buffer with vector element types. We need to scale the last
// index for the vector as a whole, then add one level of index for inside
@@ -398,7 +398,7 @@ struct ConvertAccessChain : public ConvertAliasResource<spirv::AccessChainOp> {
}
if ((srcElemType.isIntOrFloat() && dstElemType.isIntOrFloat()) ||
- (srcElemType.isa<VectorType>() && dstElemType.isa<VectorType>())) {
+ (isa<VectorType>(srcElemType) && isa<VectorType>(dstElemType))) {
// The source indices are for a buffer with larger bitwidth scalar/vector
// element types. Rewrite them into a buffer with smaller bitwidth element
// types. We only need to scale the last index.
@@ -433,10 +433,10 @@ struct ConvertLoad : public ConvertAliasResource<spirv::LoadOp> {
LogicalResult
matchAndRewrite(spirv::LoadOp loadOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto srcPtrType = loadOp.getPtr().getType().cast<spirv::PointerType>();
- auto srcElemType = srcPtrType.getPointeeType().cast<spirv::SPIRVType>();
- auto dstPtrType = adaptor.getPtr().getType().cast<spirv::PointerType>();
- auto dstElemType = dstPtrType.getPointeeType().cast<spirv::SPIRVType>();
+ auto srcPtrType = cast<spirv::PointerType>(loadOp.getPtr().getType());
+ auto srcElemType = cast<spirv::SPIRVType>(srcPtrType.getPointeeType());
+ auto dstPtrType = cast<spirv::PointerType>(adaptor.getPtr().getType());
+ auto dstElemType = cast<spirv::SPIRVType>(dstPtrType.getPointeeType());
Location loc = loadOp.getLoc();
auto newLoadOp = rewriter.create<spirv::LoadOp>(loc, adaptor.getPtr());
@@ -454,7 +454,7 @@ struct ConvertLoad : public ConvertAliasResource<spirv::LoadOp> {
}
if ((srcElemType.isIntOrFloat() && dstElemType.isIntOrFloat()) ||
- (srcElemType.isa<VectorType>() && dstElemType.isa<VectorType>())) {
+ (isa<VectorType>(srcElemType) && isa<VectorType>(dstElemType))) {
// The source and destination have scalar types of different bitwidths, or
// vector types of different component counts. For such cases, we load
// multiple smaller bitwidth values and construct a larger bitwidth one.
@@ -495,13 +495,13 @@ struct ConvertLoad : public ConvertAliasResource<spirv::LoadOp> {
// type.
Type vectorType = srcElemType;
- if (!srcElemType.isa<VectorType>())
+ if (!isa<VectorType>(srcElemType))
vectorType = VectorType::get({ratio}, dstElemType);
// If both the source and destination are vector types, we need to make
// sure the scalar type is the same for composite construction later.
- if (auto srcElemVecType = srcElemType.dyn_cast<VectorType>())
- if (auto dstElemVecType = dstElemType.dyn_cast<VectorType>()) {
+ if (auto srcElemVecType = dyn_cast<VectorType>(srcElemType))
+ if (auto dstElemVecType = dyn_cast<VectorType>(dstElemType)) {
if (srcElemVecType.getElementType() !=
dstElemVecType.getElementType()) {
int64_t count =
@@ -515,7 +515,7 @@ struct ConvertLoad : public ConvertAliasResource<spirv::LoadOp> {
Value vectorValue = rewriter.create<spirv::CompositeConstructOp>(
loc, vectorType, components);
- if (!srcElemType.isa<VectorType>())
+ if (!isa<VectorType>(srcElemType))
vectorValue =
rewriter.create<spirv::BitcastOp>(loc, srcElemType, vectorValue);
rewriter.replaceOp(loadOp, vectorValue);
@@ -534,9 +534,9 @@ struct ConvertStore : public ConvertAliasResource<spirv::StoreOp> {
matchAndRewrite(spirv::StoreOp storeOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto srcElemType =
- storeOp.getPtr().getType().cast<spirv::PointerType>().getPointeeType();
+ cast<spirv::PointerType>(storeOp.getPtr().getType()).getPointeeType();
auto dstElemType =
- adaptor.getPtr().getType().cast<spirv::PointerType>().getPointeeType();
+ cast<spirv::PointerType>(adaptor.getPtr().getType()).getPointeeType();
if (!srcElemType.isIntOrFloat() || !dstElemType.isIntOrFloat())
return rewriter.notifyMatchFailure(storeOp, "not scalar type");
if (!areSameBitwidthScalarType(srcElemType, dstElemType))
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
index 6e09a848c494..095db6b815f5 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
@@ -159,13 +159,13 @@ void UpdateVCEPass::runOnOperation() {
SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities;
for (Type valueType : valueTypes) {
typeExtensions.clear();
- valueType.cast<spirv::SPIRVType>().getExtensions(typeExtensions);
+ cast<spirv::SPIRVType>(valueType).getExtensions(typeExtensions);
if (failed(checkAndUpdateExtensionRequirements(
op, targetEnv, typeExtensions, deducedExtensions)))
return WalkResult::interrupt();
typeCapabilities.clear();
- valueType.cast<spirv::SPIRVType>().getCapabilities(typeCapabilities);
+ cast<spirv::SPIRVType>(valueType).getCapabilities(typeCapabilities);
if (failed(checkAndUpdateCapabilityRequirements(
op, targetEnv, typeCapabilities, deducedCapabilities)))
return WalkResult::interrupt();
diff --git a/mlir/lib/Dialect/SPIRV/Utils/LayoutUtils.cpp b/mlir/lib/Dialect/SPIRV/Utils/LayoutUtils.cpp
index 67d61f820b62..b19495bc3744 100644
--- a/mlir/lib/Dialect/SPIRV/Utils/LayoutUtils.cpp
+++ b/mlir/lib/Dialect/SPIRV/Utils/LayoutUtils.cpp
@@ -53,7 +53,7 @@ VulkanLayoutUtils::decorateType(spirv::StructType structType,
// must be a runtime array.
assert(memberSize != std::numeric_limits<Size>().max() ||
(i + 1 == e &&
- structType.getElementType(i).isa<spirv::RuntimeArrayType>()));
+ isa<spirv::RuntimeArrayType>(structType.getElementType(i))));
// According to the Vulkan spec:
// "A structure has a base alignment equal to the largest base alignment of
// any of its members."
@@ -79,23 +79,23 @@ VulkanLayoutUtils::decorateType(spirv::StructType structType,
Type VulkanLayoutUtils::decorateType(Type type, VulkanLayoutUtils::Size &size,
VulkanLayoutUtils::Size &alignment) {
- if (type.isa<spirv::ScalarType>()) {
+ if (isa<spirv::ScalarType>(type)) {
alignment = getScalarTypeAlignment(type);
// Vulkan spec does not specify any padding for a scalar type.
size = alignment;
return type;
}
- if (auto structType = type.dyn_cast<spirv::StructType>())
+ if (auto structType = dyn_cast<spirv::StructType>(type))
return decorateType(structType, size, alignment);
- if (auto arrayType = type.dyn_cast<spirv::ArrayType>())
+ if (auto arrayType = dyn_cast<spirv::ArrayType>(type))
return decorateType(arrayType, size, alignment);
- if (auto vectorType = type.dyn_cast<VectorType>())
+ if (auto vectorType = dyn_cast<VectorType>(type))
return decorateType(vectorType, size, alignment);
- if (auto arrayType = type.dyn_cast<spirv::RuntimeArrayType>()) {
+ if (auto arrayType = dyn_cast<spirv::RuntimeArrayType>(type)) {
size = std::numeric_limits<Size>().max();
return decorateType(arrayType, alignment);
}
- if (type.isa<spirv::PointerType>()) {
+ if (isa<spirv::PointerType>(type)) {
// TODO: Add support for `PhysicalStorageBufferAddresses`.
return nullptr;
}
@@ -161,13 +161,13 @@ VulkanLayoutUtils::getScalarTypeAlignment(Type scalarType) {
}
bool VulkanLayoutUtils::isLegalType(Type type) {
- auto ptrType = type.dyn_cast<spirv::PointerType>();
+ auto ptrType = dyn_cast<spirv::PointerType>(type);
if (!ptrType) {
return true;
}
auto storageClass = ptrType.getStorageClass();
- auto structType = ptrType.getPointeeType().dyn_cast<spirv::StructType>();
+ auto structType = dyn_cast<spirv::StructType>(ptrType.getPointeeType());
if (!structType) {
return true;
}