summaryrefslogtreecommitdiff
path: root/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp')
-rw-r--r--mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp14
1 files changed, 7 insertions, 7 deletions
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
index c0ab2152675e..9f2755da0922 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
@@ -51,19 +51,19 @@ createGlobalVarForEntryPointArgument(OpBuilder &builder, spirv::FuncOp funcOp,
// info create a variable of type !spirv.ptr<!spirv.struct<elementType>>. If
// not it must already be a !spirv.ptr<!spirv.struct<...>>.
auto varType = funcOp.getFunctionType().getInput(argIndex);
- if (varType.cast<spirv::SPIRVType>().isScalarOrVector()) {
+ if (cast<spirv::SPIRVType>(varType).isScalarOrVector()) {
auto storageClass = abiInfo.getStorageClass();
if (!storageClass)
return nullptr;
varType =
spirv::PointerType::get(spirv::StructType::get(varType), *storageClass);
}
- auto varPtrType = varType.cast<spirv::PointerType>();
- auto varPointeeType = varPtrType.getPointeeType().cast<spirv::StructType>();
+ auto varPtrType = cast<spirv::PointerType>(varType);
+ auto varPointeeType = cast<spirv::StructType>(varPtrType.getPointeeType());
// Set the offset information.
varPointeeType =
- VulkanLayoutUtils::decorateType(varPointeeType).cast<spirv::StructType>();
+ cast<spirv::StructType>(VulkanLayoutUtils::decorateType(varPointeeType));
if (!varPointeeType)
return nullptr;
@@ -98,7 +98,7 @@ getInterfaceVariables(spirv::FuncOp funcOp,
// Starting with version 1.4, the interface’s storage classes are all
// storage classes used in declaring all global variables referenced by the
// entry point’s call tree." We should consider the target environment here.
- switch (var.getType().cast<spirv::PointerType>().getStorageClass()) {
+ switch (cast<spirv::PointerType>(var.getType()).getStorageClass()) {
case spirv::StorageClass::Input:
case spirv::StorageClass::Output:
interfaceVarSet.insert(var.getOperation());
@@ -247,7 +247,7 @@ LogicalResult ProcessInterfaceVarABI::matchAndRewrite(
// at the start of the function. It is probably better to do the load just
// before the use. There might be multiple loads and currently there is no
// easy way to replace all uses with a sequence of operations.
- if (argType.value().cast<spirv::SPIRVType>().isScalarOrVector()) {
+ if (cast<spirv::SPIRVType>(argType.value()).isScalarOrVector()) {
auto zero =
spirv::ConstantOp::getZero(indexType, funcOp.getLoc(), rewriter);
auto loadPtr = rewriter.create<spirv::AccessChainOp>(
@@ -287,7 +287,7 @@ void LowerABIAttributesPass::runOnOperation() {
typeConverter.addSourceMaterialization([](OpBuilder &builder,
spirv::PointerType type,
ValueRange inputs, Location loc) {
- if (inputs.size() != 1 || !inputs[0].getType().isa<spirv::PointerType>())
+ if (inputs.size() != 1 || !isa<spirv::PointerType>(inputs[0].getType()))
return Value();
return builder.create<spirv::BitcastOp>(loc, type, inputs[0]).getResult();
});