summaryrefslogtreecommitdiff
path: root/mlir/lib/Rewrite
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Rewrite')
-rw-r--r--mlir/lib/Rewrite/ByteCode.cpp34
1 files changed, 17 insertions, 17 deletions
diff --git a/mlir/lib/Rewrite/ByteCode.cpp b/mlir/lib/Rewrite/ByteCode.cpp
index eca0297733e7..c8c442823781 100644
--- a/mlir/lib/Rewrite/ByteCode.cpp
+++ b/mlir/lib/Rewrite/ByteCode.cpp
@@ -402,7 +402,7 @@ struct ByteCodeWriter {
.Case<pdl::OperationType>(
[](Type) { return PDLValue::Kind::Operation; })
.Case<pdl::RangeType>([](pdl::RangeType rangeTy) {
- if (rangeTy.getElementType().isa<pdl::TypeType>())
+ if (isa<pdl::TypeType>(rangeTy.getElementType()))
return PDLValue::Kind::TypeRange;
return PDLValue::Kind::ValueRange;
})
@@ -538,11 +538,11 @@ void Generator::allocateMemoryIndices(pdl_interp::FuncOp matcherFunc,
ByteCodeField index = 0, typeRangeIndex = 0, valueRangeIndex = 0;
auto processRewriterValue = [&](Value val) {
valueToMemIndex.try_emplace(val, index++);
- if (pdl::RangeType rangeType = val.getType().dyn_cast<pdl::RangeType>()) {
+ if (pdl::RangeType rangeType = dyn_cast<pdl::RangeType>(val.getType())) {
Type elementTy = rangeType.getElementType();
- if (elementTy.isa<pdl::TypeType>())
+ if (isa<pdl::TypeType>(elementTy))
valueToRangeIndex.try_emplace(val, typeRangeIndex++);
- else if (elementTy.isa<pdl::ValueType>())
+ else if (isa<pdl::ValueType>(elementTy))
valueToRangeIndex.try_emplace(val, valueRangeIndex++);
}
};
@@ -611,13 +611,13 @@ void Generator::allocateMemoryIndices(pdl_interp::FuncOp matcherFunc,
/*dummyValue*/ 0);
// Check to see if this value is a range type.
- if (auto rangeTy = value.getType().dyn_cast<pdl::RangeType>()) {
+ if (auto rangeTy = dyn_cast<pdl::RangeType>(value.getType())) {
Type eleType = rangeTy.getElementType();
- if (eleType.isa<pdl::OperationType>())
+ if (isa<pdl::OperationType>(eleType))
defRangeIt->second.opRangeIndex = 0;
- else if (eleType.isa<pdl::TypeType>())
+ else if (isa<pdl::TypeType>(eleType))
defRangeIt->second.typeRangeIndex = 0;
- else if (eleType.isa<pdl::ValueType>())
+ else if (isa<pdl::ValueType>(eleType))
defRangeIt->second.valueRangeIndex = 0;
}
};
@@ -792,14 +792,14 @@ void Generator::generate(pdl_interp::ApplyRewriteOp op,
#endif
// Range results also need to append the range storage index.
- if (result.getType().isa<pdl::RangeType>())
+ if (isa<pdl::RangeType>(result.getType()))
writer.append(getRangeStorageIndex(result));
writer.append(result);
}
}
void Generator::generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer) {
Value lhs = op.getLhs();
- if (lhs.getType().isa<pdl::RangeType>()) {
+ if (isa<pdl::RangeType>(lhs.getType())) {
writer.append(OpCode::AreRangesEqual);
writer.appendPDLValueKind(lhs);
writer.append(op.getLhs(), op.getRhs(), op.getSuccessors());
@@ -945,7 +945,7 @@ void Generator::generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer) {
writer.append(OpCode::GetOperands,
index.value_or(std::numeric_limits<uint32_t>::max()),
op.getInputOp());
- if (result.getType().isa<pdl::RangeType>())
+ if (isa<pdl::RangeType>(result.getType()))
writer.append(getRangeStorageIndex(result));
else
writer.append(std::numeric_limits<ByteCodeField>::max());
@@ -965,7 +965,7 @@ void Generator::generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer) {
writer.append(OpCode::GetResults,
index.value_or(std::numeric_limits<uint32_t>::max()),
op.getInputOp());
- if (result.getType().isa<pdl::RangeType>())
+ if (isa<pdl::RangeType>(result.getType()))
writer.append(getRangeStorageIndex(result));
else
writer.append(std::numeric_limits<ByteCodeField>::max());
@@ -979,7 +979,7 @@ void Generator::generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer) {
}
void Generator::generate(pdl_interp::GetValueTypeOp op,
ByteCodeWriter &writer) {
- if (op.getType().isa<pdl::RangeType>()) {
+ if (isa<pdl::RangeType>(op.getType())) {
Value result = op.getResult();
writer.append(OpCode::GetValueRangeTypes, result,
getRangeStorageIndex(result), op.getValue());
@@ -1016,7 +1016,7 @@ void Generator::generate(pdl_interp::SwitchOperandCountOp op,
void Generator::generate(pdl_interp::SwitchOperationNameOp op,
ByteCodeWriter &writer) {
auto cases = llvm::map_range(op.getCaseValuesAttr(), [&](Attribute attr) {
- return OperationName(attr.cast<StringAttr>().getValue(), ctx);
+ return OperationName(cast<StringAttr>(attr).getValue(), ctx);
});
writer.append(OpCode::SwitchOperationName, op.getInputOp(), cases,
op.getSuccessors());
@@ -1566,7 +1566,7 @@ void ByteCodeExecutor::executeCheckTypes() {
Attribute rhs = read<Attribute>();
LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n");
- selectJump(*lhs == rhs.cast<ArrayAttr>().getAsValueRange<TypeAttr>());
+ selectJump(*lhs == cast<ArrayAttr>(rhs).getAsValueRange<TypeAttr>());
}
void ByteCodeExecutor::executeContinue() {
@@ -1581,7 +1581,7 @@ void ByteCodeExecutor::executeCreateConstantTypeRange() {
LLVM_DEBUG(llvm::dbgs() << "Executing CreateConstantTypeRange:\n");
unsigned memIndex = read();
unsigned rangeIndex = read();
- ArrayAttr typesAttr = read<Attribute>().cast<ArrayAttr>();
+ ArrayAttr typesAttr = cast<ArrayAttr>(read<Attribute>());
LLVM_DEBUG(llvm::dbgs() << " * Types: " << typesAttr << "\n\n");
assignRangeToMemory(typesAttr.getAsValueRange<TypeAttr>(), memIndex,
@@ -1743,7 +1743,7 @@ void ByteCodeExecutor::executeGetAttributeType() {
unsigned memIndex = read();
Attribute attr = read<Attribute>();
Type type;
- if (auto typedAttr = attr.dyn_cast<TypedAttr>())
+ if (auto typedAttr = dyn_cast<TypedAttr>(attr))
type = typedAttr.getType();
LLVM_DEBUG(llvm::dbgs() << " * Attribute: " << attr << "\n"