diff options
Diffstat (limited to 'mlir/lib/Dialect/SPIRV/Transforms')
7 files changed, 65 insertions, 68 deletions
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp index f154840b6f65..c22cb6710a7e 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp @@ -42,8 +42,8 @@ public: PatternRewriter &rewriter) const override { SmallVector<NamedAttribute, 4> globalVarAttrs; - auto ptrType = op.getType().cast<spirv::PointerType>(); - auto pointeeType = ptrType.getPointeeType().cast<spirv::StructType>(); + auto ptrType = cast<spirv::PointerType>(op.getType()); + auto pointeeType = cast<spirv::StructType>(ptrType.getPointeeType()); spirv::StructType structType = VulkanLayoutUtils::decorateType(pointeeType); if (!structType) diff --git a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp index c0ab2152675e..9f2755da0922 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp @@ -51,19 +51,19 @@ createGlobalVarForEntryPointArgument(OpBuilder &builder, spirv::FuncOp funcOp, // info create a variable of type !spirv.ptr<!spirv.struct<elementType>>. If // not it must already be a !spirv.ptr<!spirv.struct<...>>. auto varType = funcOp.getFunctionType().getInput(argIndex); - if (varType.cast<spirv::SPIRVType>().isScalarOrVector()) { + if (cast<spirv::SPIRVType>(varType).isScalarOrVector()) { auto storageClass = abiInfo.getStorageClass(); if (!storageClass) return nullptr; varType = spirv::PointerType::get(spirv::StructType::get(varType), *storageClass); } - auto varPtrType = varType.cast<spirv::PointerType>(); - auto varPointeeType = varPtrType.getPointeeType().cast<spirv::StructType>(); + auto varPtrType = cast<spirv::PointerType>(varType); + auto varPointeeType = cast<spirv::StructType>(varPtrType.getPointeeType()); // Set the offset information. varPointeeType = - VulkanLayoutUtils::decorateType(varPointeeType).cast<spirv::StructType>(); + cast<spirv::StructType>(VulkanLayoutUtils::decorateType(varPointeeType)); if (!varPointeeType) return nullptr; @@ -98,7 +98,7 @@ getInterfaceVariables(spirv::FuncOp funcOp, // Starting with version 1.4, the interface’s storage classes are all // storage classes used in declaring all global variables referenced by the // entry point’s call tree." We should consider the target environment here. - switch (var.getType().cast<spirv::PointerType>().getStorageClass()) { + switch (cast<spirv::PointerType>(var.getType()).getStorageClass()) { case spirv::StorageClass::Input: case spirv::StorageClass::Output: interfaceVarSet.insert(var.getOperation()); @@ -247,7 +247,7 @@ LogicalResult ProcessInterfaceVarABI::matchAndRewrite( // at the start of the function. It is probably better to do the load just // before the use. There might be multiple loads and currently there is no // easy way to replace all uses with a sequence of operations. - if (argType.value().cast<spirv::SPIRVType>().isScalarOrVector()) { + if (cast<spirv::SPIRVType>(argType.value()).isScalarOrVector()) { auto zero = spirv::ConstantOp::getZero(indexType, funcOp.getLoc(), rewriter); auto loadPtr = rewriter.create<spirv::AccessChainOp>( @@ -287,7 +287,7 @@ void LowerABIAttributesPass::runOnOperation() { typeConverter.addSourceMaterialization([](OpBuilder &builder, spirv::PointerType type, ValueRange inputs, Location loc) { - if (inputs.size() != 1 || !inputs[0].getType().isa<spirv::PointerType>()) + if (inputs.size() != 1 || !isa<spirv::PointerType>(inputs[0].getType())) return Value(); return builder.create<spirv::BitcastOp>(loc, type, inputs[0]).getResult(); }); diff --git a/mlir/lib/Dialect/SPIRV/Transforms/RewriteInsertsPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/RewriteInsertsPass.cpp index 51c36bd12db1..f38282f57a2c 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/RewriteInsertsPass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/RewriteInsertsPass.cpp @@ -84,15 +84,13 @@ void RewriteInsertsPass::runOnOperation() { LogicalResult RewriteInsertsPass::collectInsertionChain( spirv::CompositeInsertOp op, SmallVectorImpl<spirv::CompositeInsertOp> &insertions) { - auto indicesArrayAttr = op.getIndices().cast<ArrayAttr>(); + auto indicesArrayAttr = cast<ArrayAttr>(op.getIndices()); // TODO: handle nested composite object. if (indicesArrayAttr.size() == 1) { - auto numElements = op.getComposite() - .getType() - .cast<spirv::CompositeType>() + auto numElements = cast<spirv::CompositeType>(op.getComposite().getType()) .getNumElements(); - auto index = indicesArrayAttr[0].cast<IntegerAttr>().getInt(); + auto index = cast<IntegerAttr>(indicesArrayAttr[0]).getInt(); // Need a last index to collect a sequential chain. if (index + 1 != numElements) return failure(); @@ -109,9 +107,9 @@ LogicalResult RewriteInsertsPass::collectInsertionChain( return failure(); --index; - indicesArrayAttr = op.getIndices().cast<ArrayAttr>(); + indicesArrayAttr = cast<ArrayAttr>(op.getIndices()); if ((indicesArrayAttr.size() != 1) || - (indicesArrayAttr[0].cast<IntegerAttr>().getInt() != index)) + (cast<IntegerAttr>(indicesArrayAttr[0]).getInt() != index)) return failure(); } } diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp index 5a5cdfe34194..793b02520f23 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp @@ -139,7 +139,7 @@ bool SPIRVTypeConverter::allows(spirv::Capability capability) { // SPIR-V dialect. Keeping it local till the use case arises. static std::optional<int64_t> getTypeNumBytes(const SPIRVConversionOptions &options, Type type) { - if (type.isa<spirv::ScalarType>()) { + if (isa<spirv::ScalarType>(type)) { auto bitWidth = type.getIntOrFloatBitWidth(); // According to the SPIR-V spec: // "There is no physical size or bit pattern defined for values with boolean @@ -152,21 +152,21 @@ getTypeNumBytes(const SPIRVConversionOptions &options, Type type) { return bitWidth / 8; } - if (auto complexType = type.dyn_cast<ComplexType>()) { + if (auto complexType = dyn_cast<ComplexType>(type)) { auto elementSize = getTypeNumBytes(options, complexType.getElementType()); if (!elementSize) return std::nullopt; return 2 * *elementSize; } - if (auto vecType = type.dyn_cast<VectorType>()) { + if (auto vecType = dyn_cast<VectorType>(type)) { auto elementSize = getTypeNumBytes(options, vecType.getElementType()); if (!elementSize) return std::nullopt; return vecType.getNumElements() * *elementSize; } - if (auto memRefType = type.dyn_cast<MemRefType>()) { + if (auto memRefType = dyn_cast<MemRefType>(type)) { // TODO: Layout should also be controlled by the ABI attributes. For now // using the layout from MemRef. int64_t offset; @@ -198,7 +198,7 @@ getTypeNumBytes(const SPIRVConversionOptions &options, Type type) { return (offset + memrefSize) * *elementSize; } - if (auto tensorType = type.dyn_cast<TensorType>()) { + if (auto tensorType = dyn_cast<TensorType>(type)) { if (!tensorType.hasStaticShape()) return std::nullopt; @@ -246,12 +246,12 @@ convertScalarType(const spirv::TargetEnv &targetEnv, return nullptr; } - if (auto floatType = type.dyn_cast<FloatType>()) { + if (auto floatType = dyn_cast<FloatType>(type)) { LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n"); return Builder(targetEnv.getContext()).getF32Type(); } - auto intType = type.cast<IntegerType>(); + auto intType = cast<IntegerType>(type); LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n"); return IntegerType::get(targetEnv.getContext(), /*width=*/32, intType.getSignedness()); @@ -319,8 +319,8 @@ convertVectorType(const spirv::TargetEnv &targetEnv, // Get extension and capability requirements for the given type. SmallVector<ArrayRef<spirv::Extension>, 1> extensions; SmallVector<ArrayRef<spirv::Capability>, 2> capabilities; - type.cast<spirv::CompositeType>().getExtensions(extensions, storageClass); - type.cast<spirv::CompositeType>().getCapabilities(capabilities, storageClass); + cast<spirv::CompositeType>(type).getExtensions(extensions, storageClass); + cast<spirv::CompositeType>(type).getCapabilities(capabilities, storageClass); // If all requirements are met, then we can accept this type as-is. if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) && @@ -415,8 +415,8 @@ static Type convertBoolMemrefType(const spirv::TargetEnv &targetEnv, << "using non-8-bit storage for bool types unimplemented"); return nullptr; } - auto elementType = IntegerType::get(type.getContext(), numBoolBits) - .dyn_cast<spirv::ScalarType>(); + auto elementType = dyn_cast<spirv::ScalarType>( + IntegerType::get(type.getContext(), numBoolBits)); if (!elementType) return nullptr; Type arrayElemType = @@ -487,7 +487,7 @@ static Type convertSubByteMemrefType(const spirv::TargetEnv &targetEnv, static Type convertMemrefType(const spirv::TargetEnv &targetEnv, const SPIRVConversionOptions &options, MemRefType type) { - auto attr = type.getMemorySpace().dyn_cast_or_null<spirv::StorageClassAttr>(); + auto attr = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace()); if (!attr) { LLVM_DEBUG( llvm::dbgs() @@ -499,7 +499,7 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv, } spirv::StorageClass storageClass = attr.getValue(); - if (type.getElementType().isa<IntegerType>()) { + if (isa<IntegerType>(type.getElementType())) { if (type.getElementTypeBitWidth() == 1) return convertBoolMemrefType(targetEnv, options, type, storageClass); if (type.getElementTypeBitWidth() < 8) @@ -508,17 +508,17 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv, Type arrayElemType; Type elementType = type.getElementType(); - if (auto vecType = elementType.dyn_cast<VectorType>()) { + if (auto vecType = dyn_cast<VectorType>(elementType)) { arrayElemType = convertVectorType(targetEnv, options, vecType, storageClass); - } else if (auto complexType = elementType.dyn_cast<ComplexType>()) { + } else if (auto complexType = dyn_cast<ComplexType>(elementType)) { arrayElemType = convertComplexType(targetEnv, options, complexType, storageClass); - } else if (auto scalarType = elementType.dyn_cast<spirv::ScalarType>()) { + } else if (auto scalarType = dyn_cast<spirv::ScalarType>(elementType)) { arrayElemType = convertScalarType(targetEnv, options, scalarType, storageClass); - } else if (auto indexType = elementType.dyn_cast<IndexType>()) { - type = convertIndexElementType(type, options).cast<MemRefType>(); + } else if (auto indexType = dyn_cast<IndexType>(elementType)) { + type = cast<MemRefType>(convertIndexElementType(type, options)); arrayElemType = type.getElementType(); } else { LLVM_DEBUG( @@ -583,7 +583,7 @@ SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr, addConversion([this](IndexType /*indexType*/) { return getIndexType(); }); addConversion([this](IntegerType intType) -> std::optional<Type> { - if (auto scalarType = intType.dyn_cast<spirv::ScalarType>()) + if (auto scalarType = dyn_cast<spirv::ScalarType>(intType)) return convertScalarType(this->targetEnv, this->options, scalarType); if (intType.getWidth() < 8) return convertSubByteIntegerType(this->options, intType); @@ -591,7 +591,7 @@ SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr, }); addConversion([this](FloatType floatType) -> std::optional<Type> { - if (auto scalarType = floatType.dyn_cast<spirv::ScalarType>()) + if (auto scalarType = dyn_cast<spirv::ScalarType>(floatType)) return convertScalarType(this->targetEnv, this->options, scalarType); return Type(); }); @@ -784,7 +784,7 @@ static spirv::PointerType getPushConstantStorageType(unsigned elementCount, static spirv::GlobalVariableOp getPushConstantVariable(Block &body, unsigned elementCount) { for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) { - auto ptrType = varOp.getType().dyn_cast<spirv::PointerType>(); + auto ptrType = dyn_cast<spirv::PointerType>(varOp.getType()); if (!ptrType) continue; @@ -792,10 +792,9 @@ static spirv::GlobalVariableOp getPushConstantVariable(Block &body, // block statically used per shader entry point." So we should always reuse // the existing one. if (ptrType.getStorageClass() == spirv::StorageClass::PushConstant) { - auto numElements = ptrType.getPointeeType() - .cast<spirv::StructType>() - .getElementType(0) - .cast<spirv::ArrayType>() + auto numElements = cast<spirv::ArrayType>( + cast<spirv::StructType>(ptrType.getPointeeType()) + .getElementType(0)) .getNumElements(); if (numElements == elementCount) return varOp; @@ -926,8 +925,8 @@ Value mlir::spirv::getOpenCLElementPtr(SPIRVTypeConverter &typeConverter, linearizeIndex(indices, strides, offset, indexType, loc, builder); } Type pointeeType = - basePtr.getType().cast<spirv::PointerType>().getPointeeType(); - if (pointeeType.isa<spirv::ArrayType>()) { + cast<spirv::PointerType>(basePtr.getType()).getPointeeType(); + if (isa<spirv::ArrayType>(pointeeType)) { linearizedIndices.push_back(linearIndex); return builder.create<spirv::AccessChainOp>(loc, basePtr, linearizedIndices); @@ -1015,7 +1014,7 @@ bool SPIRVConversionTarget::isLegalOp(Operation *op) { // Ensure that all types have been converted to SPIRV types. if (llvm::any_of(valueTypes, - [](Type t) { return !t.isa<spirv::SPIRVType>(); })) + [](Type t) { return !isa<spirv::SPIRVType>(t); })) return false; // Special treatment for global variables, whose type requirements are @@ -1029,13 +1028,13 @@ bool SPIRVConversionTarget::isLegalOp(Operation *op) { SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities; for (Type valueType : valueTypes) { typeExtensions.clear(); - valueType.cast<spirv::SPIRVType>().getExtensions(typeExtensions); + cast<spirv::SPIRVType>(valueType).getExtensions(typeExtensions); if (failed(checkExtensionRequirements(op->getName(), this->targetEnv, typeExtensions))) return false; typeCapabilities.clear(); - valueType.cast<spirv::SPIRVType>().getCapabilities(typeCapabilities); + cast<spirv::SPIRVType>(valueType).getCapabilities(typeCapabilities); if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv, typeCapabilities))) return false; diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp index 3cd4937e96f2..44fea8678559 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp @@ -41,7 +41,7 @@ namespace { //===----------------------------------------------------------------------===// Attribute getScalarOrSplatAttr(Type type, int64_t value) { APInt sizedValue(getElementTypeOrSelf(type).getIntOrFloatBitWidth(), value); - if (auto intTy = type.dyn_cast<IntegerType>()) + if (auto intTy = dyn_cast<IntegerType>(type)) return IntegerAttr::get(intTy, sizedValue); return SplatElementsAttr::get(cast<ShapedType>(type), sizedValue); @@ -149,7 +149,7 @@ struct ExpandMulExtendedPattern final : OpRewritePattern<MulExtendedOp> { // Currently, WGSL only supports 32-bit integer types. Any other integer // types should already have been promoted/demoted to i32. - auto elemTy = getElementTypeOrSelf(lhs.getType()).cast<IntegerType>(); + auto elemTy = cast<IntegerType>(getElementTypeOrSelf(lhs.getType())); if (elemTy.getIntOrFloatBitWidth() != 32) return rewriter.notifyMatchFailure( loc, diff --git a/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp index 97f16d1b1b95..ea856c748677 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp @@ -65,16 +65,16 @@ static AliasedResourceMap collectAliasedResources(spirv::ModuleOp moduleOp) { /// `!spirv.ptr<!spirv.struct<!spirv.rtarray<...>>>`. Returns null type /// otherwise. static Type getRuntimeArrayElementType(Type type) { - auto ptrType = type.dyn_cast<spirv::PointerType>(); + auto ptrType = dyn_cast<spirv::PointerType>(type); if (!ptrType) return {}; - auto structType = ptrType.getPointeeType().dyn_cast<spirv::StructType>(); + auto structType = dyn_cast<spirv::StructType>(ptrType.getPointeeType()); if (!structType || structType.getNumElements() != 1) return {}; auto rtArrayType = - structType.getElementType(0).dyn_cast<spirv::RuntimeArrayType>(); + dyn_cast<spirv::RuntimeArrayType>(structType.getElementType(0)); if (!rtArrayType) return {}; @@ -97,7 +97,7 @@ deduceCanonicalResource(ArrayRef<spirv::SPIRVType> types) { for (const auto &indexedTypes : llvm::enumerate(types)) { spirv::SPIRVType type = indexedTypes.value(); assert(type.isScalarOrVector()); - if (auto vectorType = type.dyn_cast<VectorType>()) { + if (auto vectorType = dyn_cast<VectorType>(type)) { if (vectorType.getNumElements() % 2 != 0) return std::nullopt; // Odd-sized vector has special layout // requirements. @@ -277,7 +277,7 @@ void ResourceAliasAnalysis::recordIfUnifiable( if (!elementType) return; // Unexpected resource variable type. - auto type = elementType.cast<spirv::SPIRVType>(); + auto type = cast<spirv::SPIRVType>(elementType); if (!type.isScalarOrVector()) return; // Unexpected resource element type. @@ -370,7 +370,7 @@ struct ConvertAccessChain : public ConvertAliasResource<spirv::AccessChainOp> { Location loc = acOp.getLoc(); - if (srcElemType.isIntOrFloat() && dstElemType.isa<VectorType>()) { + if (srcElemType.isIntOrFloat() && isa<VectorType>(dstElemType)) { // The source indices are for a buffer with scalar element types. Rewrite // them into a buffer with vector element types. We need to scale the last // index for the vector as a whole, then add one level of index for inside @@ -398,7 +398,7 @@ struct ConvertAccessChain : public ConvertAliasResource<spirv::AccessChainOp> { } if ((srcElemType.isIntOrFloat() && dstElemType.isIntOrFloat()) || - (srcElemType.isa<VectorType>() && dstElemType.isa<VectorType>())) { + (isa<VectorType>(srcElemType) && isa<VectorType>(dstElemType))) { // The source indices are for a buffer with larger bitwidth scalar/vector // element types. Rewrite them into a buffer with smaller bitwidth element // types. We only need to scale the last index. @@ -433,10 +433,10 @@ struct ConvertLoad : public ConvertAliasResource<spirv::LoadOp> { LogicalResult matchAndRewrite(spirv::LoadOp loadOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto srcPtrType = loadOp.getPtr().getType().cast<spirv::PointerType>(); - auto srcElemType = srcPtrType.getPointeeType().cast<spirv::SPIRVType>(); - auto dstPtrType = adaptor.getPtr().getType().cast<spirv::PointerType>(); - auto dstElemType = dstPtrType.getPointeeType().cast<spirv::SPIRVType>(); + auto srcPtrType = cast<spirv::PointerType>(loadOp.getPtr().getType()); + auto srcElemType = cast<spirv::SPIRVType>(srcPtrType.getPointeeType()); + auto dstPtrType = cast<spirv::PointerType>(adaptor.getPtr().getType()); + auto dstElemType = cast<spirv::SPIRVType>(dstPtrType.getPointeeType()); Location loc = loadOp.getLoc(); auto newLoadOp = rewriter.create<spirv::LoadOp>(loc, adaptor.getPtr()); @@ -454,7 +454,7 @@ struct ConvertLoad : public ConvertAliasResource<spirv::LoadOp> { } if ((srcElemType.isIntOrFloat() && dstElemType.isIntOrFloat()) || - (srcElemType.isa<VectorType>() && dstElemType.isa<VectorType>())) { + (isa<VectorType>(srcElemType) && isa<VectorType>(dstElemType))) { // The source and destination have scalar types of different bitwidths, or // vector types of different component counts. For such cases, we load // multiple smaller bitwidth values and construct a larger bitwidth one. @@ -495,13 +495,13 @@ struct ConvertLoad : public ConvertAliasResource<spirv::LoadOp> { // type. Type vectorType = srcElemType; - if (!srcElemType.isa<VectorType>()) + if (!isa<VectorType>(srcElemType)) vectorType = VectorType::get({ratio}, dstElemType); // If both the source and destination are vector types, we need to make // sure the scalar type is the same for composite construction later. - if (auto srcElemVecType = srcElemType.dyn_cast<VectorType>()) - if (auto dstElemVecType = dstElemType.dyn_cast<VectorType>()) { + if (auto srcElemVecType = dyn_cast<VectorType>(srcElemType)) + if (auto dstElemVecType = dyn_cast<VectorType>(dstElemType)) { if (srcElemVecType.getElementType() != dstElemVecType.getElementType()) { int64_t count = @@ -515,7 +515,7 @@ struct ConvertLoad : public ConvertAliasResource<spirv::LoadOp> { Value vectorValue = rewriter.create<spirv::CompositeConstructOp>( loc, vectorType, components); - if (!srcElemType.isa<VectorType>()) + if (!isa<VectorType>(srcElemType)) vectorValue = rewriter.create<spirv::BitcastOp>(loc, srcElemType, vectorValue); rewriter.replaceOp(loadOp, vectorValue); @@ -534,9 +534,9 @@ struct ConvertStore : public ConvertAliasResource<spirv::StoreOp> { matchAndRewrite(spirv::StoreOp storeOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto srcElemType = - storeOp.getPtr().getType().cast<spirv::PointerType>().getPointeeType(); + cast<spirv::PointerType>(storeOp.getPtr().getType()).getPointeeType(); auto dstElemType = - adaptor.getPtr().getType().cast<spirv::PointerType>().getPointeeType(); + cast<spirv::PointerType>(adaptor.getPtr().getType()).getPointeeType(); if (!srcElemType.isIntOrFloat() || !dstElemType.isIntOrFloat()) return rewriter.notifyMatchFailure(storeOp, "not scalar type"); if (!areSameBitwidthScalarType(srcElemType, dstElemType)) diff --git a/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp index 6e09a848c494..095db6b815f5 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp @@ -159,13 +159,13 @@ void UpdateVCEPass::runOnOperation() { SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities; for (Type valueType : valueTypes) { typeExtensions.clear(); - valueType.cast<spirv::SPIRVType>().getExtensions(typeExtensions); + cast<spirv::SPIRVType>(valueType).getExtensions(typeExtensions); if (failed(checkAndUpdateExtensionRequirements( op, targetEnv, typeExtensions, deducedExtensions))) return WalkResult::interrupt(); typeCapabilities.clear(); - valueType.cast<spirv::SPIRVType>().getCapabilities(typeCapabilities); + cast<spirv::SPIRVType>(valueType).getCapabilities(typeCapabilities); if (failed(checkAndUpdateCapabilityRequirements( op, targetEnv, typeCapabilities, deducedCapabilities))) return WalkResult::interrupt(); |