diff options
author | Tres Popp <tpopp@google.com> | 2023-05-11 11:10:46 +0200 |
---|---|---|
committer | Tres Popp <tpopp@google.com> | 2023-05-12 11:21:30 +0200 |
commit | c1fa60b4cde512964544ab66404dea79dbc5dcb4 (patch) | |
tree | 729fa03855abbe296208554fa4a7fd2dc742ab6b /mlir/test | |
parent | 5550c821897ab77e664977121a0e90ad5be1ff59 (diff) | |
download | llvm-c1fa60b4cde512964544ab66404dea79dbc5dcb4.tar.gz |
[mlir] Update method cast calls to function calls
The MLIR classes Type/Attribute/Operation/Op/Value support
cast/dyn_cast/isa/dyn_cast_or_null functionality through llvm's doCast
functionality in addition to defining methods with the same name.
This change begins the migration of uses of the method to the
corresponding function call as has been decided as more consistent.
Note that there still exist classes that only define methods directly,
such as AffineExpr, and this does not include work currently to support
a functional cast/isa call.
Context:
* https://mlir.llvm.org/deprecation/ at "Use the free function variants for dyn_cast/cast/isa/…"
* Original discussion at https://discourse.llvm.org/t/preferred-casting-style-going-forward/68443
Implementation:
This follows a previous patch that updated calls
`op.cast<T>()-> cast<T>(op)`. However some cases could not handle an
unprefixed `cast` call due to occurrences of variables named cast, or
occurring inside of class definitions which would resolve to the method.
All C++ files that did not work automatically with `cast<T>()` are
updated here to `llvm::cast` and similar with the intention that they
can be easily updated after the methods are removed through a
find-replace.
See https://github.com/llvm/llvm-project/compare/main...tpopp:llvm-project:tidy-cast-check
for the clang-tidy check that is used and then update printed
occurrences of the function to include `llvm::` before.
One can then run the following:
```
ninja -C $BUILD_DIR clang-tidy
run-clang-tidy -clang-tidy-binary=$BUILD_DIR/bin/clang-tidy -checks='-*,misc-cast-functions'\
-export-fixes /tmp/cast/casts.yaml mlir/*\
-header-filter=mlir/ -fix
rm -rf $BUILD_DIR/tools/mlir/**/*.inc
```
Differential Revision: https://reviews.llvm.org/D150348
Diffstat (limited to 'mlir/test')
-rw-r--r-- | mlir/test/lib/Dialect/Test/TestAttributes.cpp | 2 | ||||
-rw-r--r-- | mlir/test/lib/Dialect/Test/TestTypes.cpp | 17 | ||||
-rw-r--r-- | mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp | 37 | ||||
-rw-r--r-- | mlir/test/python/lib/PythonTestCAPI.cpp | 4 |
4 files changed, 31 insertions, 29 deletions
diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.cpp b/mlir/test/lib/Dialect/Test/TestAttributes.cpp index e0ccd500aa90..7fc2e6ab3ec0 100644 --- a/mlir/test/lib/Dialect/Test/TestAttributes.cpp +++ b/mlir/test/lib/Dialect/Test/TestAttributes.cpp @@ -75,7 +75,7 @@ Attribute TestI64ElementsAttr::parse(AsmParser &parser, Type type) { if (parser.parseRSquare() || parser.parseGreater()) return Attribute(); return parser.getChecked<TestI64ElementsAttr>( - parser.getContext(), type.cast<ShapedType>(), elements); + parser.getContext(), llvm::cast<ShapedType>(type), elements); } void TestI64ElementsAttr::print(AsmPrinter &printer) const { diff --git a/mlir/test/lib/Dialect/Test/TestTypes.cpp b/mlir/test/lib/Dialect/Test/TestTypes.cpp index 231f69f629ce..0633752067a1 100644 --- a/mlir/test/lib/Dialect/Test/TestTypes.cpp +++ b/mlir/test/lib/Dialect/Test/TestTypes.cpp @@ -287,18 +287,18 @@ TestTypeWithLayoutType::verifyEntries(DataLayoutEntryListRef params, for (DataLayoutEntryInterface entry : params) { // This is for testing purposes only, so assert well-formedness. assert(entry.isTypeEntry() && "unexpected identifier entry"); - assert(entry.getKey().get<Type>().isa<TestTypeWithLayoutType>() && + assert(llvm::isa<TestTypeWithLayoutType>(entry.getKey().get<Type>()) && "wrong type passed in"); - auto array = entry.getValue().dyn_cast<ArrayAttr>(); + auto array = llvm::dyn_cast<ArrayAttr>(entry.getValue()); assert(array && array.getValue().size() == 2 && "expected array of two elements"); - auto kind = array.getValue().front().dyn_cast<StringAttr>(); + auto kind = llvm::dyn_cast<StringAttr>(array.getValue().front()); (void)kind; assert(kind && (kind.getValue() == "size" || kind.getValue() == "alignment" || kind.getValue() == "preferred") && "unexpected kind"); - assert(array.getValue().back().isa<IntegerAttr>()); + assert(llvm::isa<IntegerAttr>(array.getValue().back())); } return success(); } @@ -306,10 +306,11 @@ TestTypeWithLayoutType::verifyEntries(DataLayoutEntryListRef params, unsigned TestTypeWithLayoutType::extractKind(DataLayoutEntryListRef params, StringRef expectedKind) const { for (DataLayoutEntryInterface entry : params) { - ArrayRef<Attribute> pair = entry.getValue().cast<ArrayAttr>().getValue(); - StringRef kind = pair.front().cast<StringAttr>().getValue(); + ArrayRef<Attribute> pair = + llvm::cast<ArrayAttr>(entry.getValue()).getValue(); + StringRef kind = llvm::cast<StringAttr>(pair.front()).getValue(); if (kind == expectedKind) - return pair.back().cast<IntegerAttr>().getValue().getZExtValue(); + return llvm::cast<IntegerAttr>(pair.back()).getValue().getZExtValue(); } return 1; } @@ -466,7 +467,7 @@ void TestDialect::printTestType(Type type, AsmPrinter &printer, if (succeeded(printIfDynamicType(type, printer))) return; - auto rec = type.cast<TestRecursiveType>(); + auto rec = llvm::cast<TestRecursiveType>(type); printer << "test_rec<" << rec.getName(); if (!stack.contains(rec)) { printer << ", "; diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp index 9cb26c11eb89..1d1bbc3a5708 100644 --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp @@ -109,10 +109,10 @@ DiagnosedSilenceableFailure mlir::test::TestProduceSelfHandleOrForwardOperandOp::apply( transform::TransformResults &results, transform::TransformState &state) { if (getOperation()->getNumOperands() != 0) { - results.set(getResult().cast<OpResult>(), + results.set(llvm::cast<OpResult>(getResult()), getOperation()->getOperand(0).getDefiningOp()); } else { - results.set(getResult().cast<OpResult>(), getOperation()); + results.set(llvm::cast<OpResult>(getResult()), getOperation()); } return DiagnosedSilenceableFailure::success(); } @@ -127,7 +127,7 @@ void mlir::test::TestProduceSelfHandleOrForwardOperandOp::getEffects( DiagnosedSilenceableFailure mlir::test::TestProduceValueHandleToSelfOperand::apply( transform::TransformResults &results, transform::TransformState &state) { - results.setValues(getOut().cast<OpResult>(), getIn()); + results.setValues(llvm::cast<OpResult>(getOut()), getIn()); return DiagnosedSilenceableFailure::success(); } @@ -249,13 +249,13 @@ DiagnosedSilenceableFailure mlir::test::TestPrintRemarkAtOperandValue::apply( for (Value value : values) { std::string note; llvm::raw_string_ostream os(note); - if (auto arg = value.dyn_cast<BlockArgument>()) { + if (auto arg = llvm::dyn_cast<BlockArgument>(value)) { os << "a block argument #" << arg.getArgNumber() << " in block #" << std::distance(arg.getOwner()->getParent()->begin(), arg.getOwner()->getIterator()) << " in region #" << arg.getOwner()->getParent()->getRegionNumber(); } else { - os << "an op result #" << value.cast<OpResult>().getResultNumber(); + os << "an op result #" << llvm::cast<OpResult>(value).getResultNumber(); } InFlightDiagnostic diag = ::emitRemark(value.getLoc()) << getMessage(); diag.attachNote() << "value handle points to " << os.str(); @@ -317,7 +317,7 @@ DiagnosedSilenceableFailure mlir::test::TestRemapOperandPayloadToSelfOp::apply( getOperation()))) return DiagnosedSilenceableFailure::definiteFailure(); if (getNumResults() > 0) - results.set(getResult(0).cast<OpResult>(), getOperation()); + results.set(llvm::cast<OpResult>(getResult(0)), getOperation()); return DiagnosedSilenceableFailure::success(); } @@ -339,7 +339,7 @@ mlir::test::TestReversePayloadOpsOp::apply(transform::TransformResults &results, transform::TransformState &state) { ArrayRef<Operation *> payloadOps = state.getPayloadOps(getTarget()); auto reversedOps = llvm::to_vector(llvm::reverse(payloadOps)); - results.set(getResult().cast<OpResult>(), reversedOps); + results.set(llvm::cast<OpResult>(getResult()), reversedOps); return DiagnosedSilenceableFailure::success(); } @@ -443,7 +443,8 @@ void mlir::test::TestPrintNumberOfAssociatedPayloadIROps::getEffects( DiagnosedSilenceableFailure mlir::test::TestCopyPayloadOp::apply(transform::TransformResults &results, transform::TransformState &state) { - results.set(getCopy().cast<OpResult>(), state.getPayloadOps(getHandle())); + results.set(llvm::cast<OpResult>(getCopy()), + state.getPayloadOps(getHandle())); return DiagnosedSilenceableFailure::success(); } @@ -472,7 +473,7 @@ DiagnosedSilenceableFailure mlir::transform::TestDialectOpType::checkPayload( DiagnosedSilenceableFailure mlir::transform::TestDialectParamType::checkPayload( Location loc, ArrayRef<Attribute> payload) const { for (Attribute attr : payload) { - auto integerAttr = attr.dyn_cast<IntegerAttr>(); + auto integerAttr = llvm::dyn_cast<IntegerAttr>(attr); if (integerAttr && integerAttr.getType().isSignlessInteger(32)) continue; return emitSilenceableError(loc) @@ -534,7 +535,7 @@ mlir::test::TestAddToParamOp::apply(transform::TransformResults &results, if (Value param = getParam()) { values = llvm::to_vector( llvm::map_range(state.getParams(param), [](Attribute attr) -> uint32_t { - return attr.cast<IntegerAttr>().getValue().getLimitedValue( + return llvm::cast<IntegerAttr>(attr).getValue().getLimitedValue( UINT32_MAX); })); } @@ -544,7 +545,7 @@ mlir::test::TestAddToParamOp::apply(transform::TransformResults &results, llvm::map_range(values, [this, &builder](uint32_t value) -> Attribute { return builder.getI32IntegerAttr(value + getAddendum()); })); - results.setParams(getResult().cast<OpResult>(), result); + results.setParams(llvm::cast<OpResult>(getResult()), result); return DiagnosedSilenceableFailure::success(); } @@ -562,7 +563,7 @@ mlir::test::TestProduceParamWithNumberOfTestOps::apply( }); return builder.getI32IntegerAttr(count); })); - results.setParams(getResult().cast<OpResult>(), result); + results.setParams(llvm::cast<OpResult>(getResult()), result); return DiagnosedSilenceableFailure::success(); } @@ -570,12 +571,12 @@ DiagnosedSilenceableFailure mlir::test::TestProduceIntegerParamWithTypeOp::apply( transform::TransformResults &results, transform::TransformState &state) { Attribute zero = IntegerAttr::get(getType(), 0); - results.setParams(getResult().cast<OpResult>(), zero); + results.setParams(llvm::cast<OpResult>(getResult()), zero); return DiagnosedSilenceableFailure::success(); } LogicalResult mlir::test::TestProduceIntegerParamWithTypeOp::verify() { - if (!getType().isa<IntegerType>()) { + if (!llvm::isa<IntegerType>(getType())) { return emitOpError() << "expects an integer type"; } return success(); @@ -618,7 +619,7 @@ void mlir::test::TestProduceNullPayloadOp::getEffects( DiagnosedSilenceableFailure mlir::test::TestProduceNullPayloadOp::apply( transform::TransformResults &results, transform::TransformState &state) { SmallVector<Operation *, 1> null({nullptr}); - results.set(getOut().cast<OpResult>(), null); + results.set(llvm::cast<OpResult>(getOut()), null); return DiagnosedSilenceableFailure::success(); } @@ -630,7 +631,7 @@ void mlir::test::TestProduceNullParamOp::getEffects( DiagnosedSilenceableFailure mlir::test::TestProduceNullParamOp::apply(transform::TransformResults &results, transform::TransformState &state) { - results.setParams(getOut().cast<OpResult>(), Attribute()); + results.setParams(llvm::cast<OpResult>(getOut()), Attribute()); return DiagnosedSilenceableFailure::success(); } @@ -642,7 +643,7 @@ void mlir::test::TestProduceNullValueOp::getEffects( DiagnosedSilenceableFailure mlir::test::TestProduceNullValueOp::apply(transform::TransformResults &results, transform::TransformState &state) { - results.setValues(getOut().cast<OpResult>(), Value()); + results.setValues(llvm::cast<OpResult>(getOut()), Value()); return DiagnosedSilenceableFailure::success(); } @@ -662,7 +663,7 @@ void mlir::test::TestRequiredMemoryEffectsOp::getEffects( DiagnosedSilenceableFailure mlir::test::TestRequiredMemoryEffectsOp::apply( transform::TransformResults &results, transform::TransformState &state) { - results.set(getOut().cast<OpResult>(), state.getPayloadOps(getIn())); + results.set(llvm::cast<OpResult>(getOut()), state.getPayloadOps(getIn())); return DiagnosedSilenceableFailure::success(); } diff --git a/mlir/test/python/lib/PythonTestCAPI.cpp b/mlir/test/python/lib/PythonTestCAPI.cpp index 280cfa0b1738..7b443554440b 100644 --- a/mlir/test/python/lib/PythonTestCAPI.cpp +++ b/mlir/test/python/lib/PythonTestCAPI.cpp @@ -16,7 +16,7 @@ MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(PythonTest, python_test, python_test::PythonTestDialect) bool mlirAttributeIsAPythonTestTestAttribute(MlirAttribute attr) { - return unwrap(attr).isa<python_test::TestAttrAttr>(); + return llvm::isa<python_test::TestAttrAttr>(unwrap(attr)); } MlirAttribute mlirPythonTestTestAttributeGet(MlirContext context) { @@ -24,7 +24,7 @@ MlirAttribute mlirPythonTestTestAttributeGet(MlirContext context) { } bool mlirTypeIsAPythonTestTestType(MlirType type) { - return unwrap(type).isa<python_test::TestTypeType>(); + return llvm::isa<python_test::TestTypeType>(unwrap(type)); } MlirType mlirPythonTestTestTypeGet(MlirContext context) { |