summaryrefslogtreecommitdiff
path: root/mlir/test
diff options
context:
space:
mode:
authorTres Popp <tpopp@google.com>2023-05-11 11:10:46 +0200
committerTres Popp <tpopp@google.com>2023-05-12 11:21:30 +0200
commitc1fa60b4cde512964544ab66404dea79dbc5dcb4 (patch)
tree729fa03855abbe296208554fa4a7fd2dc742ab6b /mlir/test
parent5550c821897ab77e664977121a0e90ad5be1ff59 (diff)
downloadllvm-c1fa60b4cde512964544ab66404dea79dbc5dcb4.tar.gz
[mlir] Update method cast calls to function calls
The MLIR classes Type/Attribute/Operation/Op/Value support cast/dyn_cast/isa/dyn_cast_or_null functionality through llvm's doCast functionality in addition to defining methods with the same name. This change begins the migration of uses of the method to the corresponding function call as has been decided as more consistent. Note that there still exist classes that only define methods directly, such as AffineExpr, and this does not include work currently to support a functional cast/isa call. Context: * https://mlir.llvm.org/deprecation/ at "Use the free function variants for dyn_cast/cast/isa/…" * Original discussion at https://discourse.llvm.org/t/preferred-casting-style-going-forward/68443 Implementation: This follows a previous patch that updated calls `op.cast<T>()-> cast<T>(op)`. However some cases could not handle an unprefixed `cast` call due to occurrences of variables named cast, or occurring inside of class definitions which would resolve to the method. All C++ files that did not work automatically with `cast<T>()` are updated here to `llvm::cast` and similar with the intention that they can be easily updated after the methods are removed through a find-replace. See https://github.com/llvm/llvm-project/compare/main...tpopp:llvm-project:tidy-cast-check for the clang-tidy check that is used and then update printed occurrences of the function to include `llvm::` before. One can then run the following: ``` ninja -C $BUILD_DIR clang-tidy run-clang-tidy -clang-tidy-binary=$BUILD_DIR/bin/clang-tidy -checks='-*,misc-cast-functions'\ -export-fixes /tmp/cast/casts.yaml mlir/*\ -header-filter=mlir/ -fix rm -rf $BUILD_DIR/tools/mlir/**/*.inc ``` Differential Revision: https://reviews.llvm.org/D150348
Diffstat (limited to 'mlir/test')
-rw-r--r--mlir/test/lib/Dialect/Test/TestAttributes.cpp2
-rw-r--r--mlir/test/lib/Dialect/Test/TestTypes.cpp17
-rw-r--r--mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp37
-rw-r--r--mlir/test/python/lib/PythonTestCAPI.cpp4
4 files changed, 31 insertions, 29 deletions
diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.cpp b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
index e0ccd500aa90..7fc2e6ab3ec0 100644
--- a/mlir/test/lib/Dialect/Test/TestAttributes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
@@ -75,7 +75,7 @@ Attribute TestI64ElementsAttr::parse(AsmParser &parser, Type type) {
if (parser.parseRSquare() || parser.parseGreater())
return Attribute();
return parser.getChecked<TestI64ElementsAttr>(
- parser.getContext(), type.cast<ShapedType>(), elements);
+ parser.getContext(), llvm::cast<ShapedType>(type), elements);
}
void TestI64ElementsAttr::print(AsmPrinter &printer) const {
diff --git a/mlir/test/lib/Dialect/Test/TestTypes.cpp b/mlir/test/lib/Dialect/Test/TestTypes.cpp
index 231f69f629ce..0633752067a1 100644
--- a/mlir/test/lib/Dialect/Test/TestTypes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestTypes.cpp
@@ -287,18 +287,18 @@ TestTypeWithLayoutType::verifyEntries(DataLayoutEntryListRef params,
for (DataLayoutEntryInterface entry : params) {
// This is for testing purposes only, so assert well-formedness.
assert(entry.isTypeEntry() && "unexpected identifier entry");
- assert(entry.getKey().get<Type>().isa<TestTypeWithLayoutType>() &&
+ assert(llvm::isa<TestTypeWithLayoutType>(entry.getKey().get<Type>()) &&
"wrong type passed in");
- auto array = entry.getValue().dyn_cast<ArrayAttr>();
+ auto array = llvm::dyn_cast<ArrayAttr>(entry.getValue());
assert(array && array.getValue().size() == 2 &&
"expected array of two elements");
- auto kind = array.getValue().front().dyn_cast<StringAttr>();
+ auto kind = llvm::dyn_cast<StringAttr>(array.getValue().front());
(void)kind;
assert(kind &&
(kind.getValue() == "size" || kind.getValue() == "alignment" ||
kind.getValue() == "preferred") &&
"unexpected kind");
- assert(array.getValue().back().isa<IntegerAttr>());
+ assert(llvm::isa<IntegerAttr>(array.getValue().back()));
}
return success();
}
@@ -306,10 +306,11 @@ TestTypeWithLayoutType::verifyEntries(DataLayoutEntryListRef params,
unsigned TestTypeWithLayoutType::extractKind(DataLayoutEntryListRef params,
StringRef expectedKind) const {
for (DataLayoutEntryInterface entry : params) {
- ArrayRef<Attribute> pair = entry.getValue().cast<ArrayAttr>().getValue();
- StringRef kind = pair.front().cast<StringAttr>().getValue();
+ ArrayRef<Attribute> pair =
+ llvm::cast<ArrayAttr>(entry.getValue()).getValue();
+ StringRef kind = llvm::cast<StringAttr>(pair.front()).getValue();
if (kind == expectedKind)
- return pair.back().cast<IntegerAttr>().getValue().getZExtValue();
+ return llvm::cast<IntegerAttr>(pair.back()).getValue().getZExtValue();
}
return 1;
}
@@ -466,7 +467,7 @@ void TestDialect::printTestType(Type type, AsmPrinter &printer,
if (succeeded(printIfDynamicType(type, printer)))
return;
- auto rec = type.cast<TestRecursiveType>();
+ auto rec = llvm::cast<TestRecursiveType>(type);
printer << "test_rec<" << rec.getName();
if (!stack.contains(rec)) {
printer << ", ";
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
index 9cb26c11eb89..1d1bbc3a5708 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
@@ -109,10 +109,10 @@ DiagnosedSilenceableFailure
mlir::test::TestProduceSelfHandleOrForwardOperandOp::apply(
transform::TransformResults &results, transform::TransformState &state) {
if (getOperation()->getNumOperands() != 0) {
- results.set(getResult().cast<OpResult>(),
+ results.set(llvm::cast<OpResult>(getResult()),
getOperation()->getOperand(0).getDefiningOp());
} else {
- results.set(getResult().cast<OpResult>(), getOperation());
+ results.set(llvm::cast<OpResult>(getResult()), getOperation());
}
return DiagnosedSilenceableFailure::success();
}
@@ -127,7 +127,7 @@ void mlir::test::TestProduceSelfHandleOrForwardOperandOp::getEffects(
DiagnosedSilenceableFailure
mlir::test::TestProduceValueHandleToSelfOperand::apply(
transform::TransformResults &results, transform::TransformState &state) {
- results.setValues(getOut().cast<OpResult>(), getIn());
+ results.setValues(llvm::cast<OpResult>(getOut()), getIn());
return DiagnosedSilenceableFailure::success();
}
@@ -249,13 +249,13 @@ DiagnosedSilenceableFailure mlir::test::TestPrintRemarkAtOperandValue::apply(
for (Value value : values) {
std::string note;
llvm::raw_string_ostream os(note);
- if (auto arg = value.dyn_cast<BlockArgument>()) {
+ if (auto arg = llvm::dyn_cast<BlockArgument>(value)) {
os << "a block argument #" << arg.getArgNumber() << " in block #"
<< std::distance(arg.getOwner()->getParent()->begin(),
arg.getOwner()->getIterator())
<< " in region #" << arg.getOwner()->getParent()->getRegionNumber();
} else {
- os << "an op result #" << value.cast<OpResult>().getResultNumber();
+ os << "an op result #" << llvm::cast<OpResult>(value).getResultNumber();
}
InFlightDiagnostic diag = ::emitRemark(value.getLoc()) << getMessage();
diag.attachNote() << "value handle points to " << os.str();
@@ -317,7 +317,7 @@ DiagnosedSilenceableFailure mlir::test::TestRemapOperandPayloadToSelfOp::apply(
getOperation())))
return DiagnosedSilenceableFailure::definiteFailure();
if (getNumResults() > 0)
- results.set(getResult(0).cast<OpResult>(), getOperation());
+ results.set(llvm::cast<OpResult>(getResult(0)), getOperation());
return DiagnosedSilenceableFailure::success();
}
@@ -339,7 +339,7 @@ mlir::test::TestReversePayloadOpsOp::apply(transform::TransformResults &results,
transform::TransformState &state) {
ArrayRef<Operation *> payloadOps = state.getPayloadOps(getTarget());
auto reversedOps = llvm::to_vector(llvm::reverse(payloadOps));
- results.set(getResult().cast<OpResult>(), reversedOps);
+ results.set(llvm::cast<OpResult>(getResult()), reversedOps);
return DiagnosedSilenceableFailure::success();
}
@@ -443,7 +443,8 @@ void mlir::test::TestPrintNumberOfAssociatedPayloadIROps::getEffects(
DiagnosedSilenceableFailure
mlir::test::TestCopyPayloadOp::apply(transform::TransformResults &results,
transform::TransformState &state) {
- results.set(getCopy().cast<OpResult>(), state.getPayloadOps(getHandle()));
+ results.set(llvm::cast<OpResult>(getCopy()),
+ state.getPayloadOps(getHandle()));
return DiagnosedSilenceableFailure::success();
}
@@ -472,7 +473,7 @@ DiagnosedSilenceableFailure mlir::transform::TestDialectOpType::checkPayload(
DiagnosedSilenceableFailure mlir::transform::TestDialectParamType::checkPayload(
Location loc, ArrayRef<Attribute> payload) const {
for (Attribute attr : payload) {
- auto integerAttr = attr.dyn_cast<IntegerAttr>();
+ auto integerAttr = llvm::dyn_cast<IntegerAttr>(attr);
if (integerAttr && integerAttr.getType().isSignlessInteger(32))
continue;
return emitSilenceableError(loc)
@@ -534,7 +535,7 @@ mlir::test::TestAddToParamOp::apply(transform::TransformResults &results,
if (Value param = getParam()) {
values = llvm::to_vector(
llvm::map_range(state.getParams(param), [](Attribute attr) -> uint32_t {
- return attr.cast<IntegerAttr>().getValue().getLimitedValue(
+ return llvm::cast<IntegerAttr>(attr).getValue().getLimitedValue(
UINT32_MAX);
}));
}
@@ -544,7 +545,7 @@ mlir::test::TestAddToParamOp::apply(transform::TransformResults &results,
llvm::map_range(values, [this, &builder](uint32_t value) -> Attribute {
return builder.getI32IntegerAttr(value + getAddendum());
}));
- results.setParams(getResult().cast<OpResult>(), result);
+ results.setParams(llvm::cast<OpResult>(getResult()), result);
return DiagnosedSilenceableFailure::success();
}
@@ -562,7 +563,7 @@ mlir::test::TestProduceParamWithNumberOfTestOps::apply(
});
return builder.getI32IntegerAttr(count);
}));
- results.setParams(getResult().cast<OpResult>(), result);
+ results.setParams(llvm::cast<OpResult>(getResult()), result);
return DiagnosedSilenceableFailure::success();
}
@@ -570,12 +571,12 @@ DiagnosedSilenceableFailure
mlir::test::TestProduceIntegerParamWithTypeOp::apply(
transform::TransformResults &results, transform::TransformState &state) {
Attribute zero = IntegerAttr::get(getType(), 0);
- results.setParams(getResult().cast<OpResult>(), zero);
+ results.setParams(llvm::cast<OpResult>(getResult()), zero);
return DiagnosedSilenceableFailure::success();
}
LogicalResult mlir::test::TestProduceIntegerParamWithTypeOp::verify() {
- if (!getType().isa<IntegerType>()) {
+ if (!llvm::isa<IntegerType>(getType())) {
return emitOpError() << "expects an integer type";
}
return success();
@@ -618,7 +619,7 @@ void mlir::test::TestProduceNullPayloadOp::getEffects(
DiagnosedSilenceableFailure mlir::test::TestProduceNullPayloadOp::apply(
transform::TransformResults &results, transform::TransformState &state) {
SmallVector<Operation *, 1> null({nullptr});
- results.set(getOut().cast<OpResult>(), null);
+ results.set(llvm::cast<OpResult>(getOut()), null);
return DiagnosedSilenceableFailure::success();
}
@@ -630,7 +631,7 @@ void mlir::test::TestProduceNullParamOp::getEffects(
DiagnosedSilenceableFailure
mlir::test::TestProduceNullParamOp::apply(transform::TransformResults &results,
transform::TransformState &state) {
- results.setParams(getOut().cast<OpResult>(), Attribute());
+ results.setParams(llvm::cast<OpResult>(getOut()), Attribute());
return DiagnosedSilenceableFailure::success();
}
@@ -642,7 +643,7 @@ void mlir::test::TestProduceNullValueOp::getEffects(
DiagnosedSilenceableFailure
mlir::test::TestProduceNullValueOp::apply(transform::TransformResults &results,
transform::TransformState &state) {
- results.setValues(getOut().cast<OpResult>(), Value());
+ results.setValues(llvm::cast<OpResult>(getOut()), Value());
return DiagnosedSilenceableFailure::success();
}
@@ -662,7 +663,7 @@ void mlir::test::TestRequiredMemoryEffectsOp::getEffects(
DiagnosedSilenceableFailure mlir::test::TestRequiredMemoryEffectsOp::apply(
transform::TransformResults &results, transform::TransformState &state) {
- results.set(getOut().cast<OpResult>(), state.getPayloadOps(getIn()));
+ results.set(llvm::cast<OpResult>(getOut()), state.getPayloadOps(getIn()));
return DiagnosedSilenceableFailure::success();
}
diff --git a/mlir/test/python/lib/PythonTestCAPI.cpp b/mlir/test/python/lib/PythonTestCAPI.cpp
index 280cfa0b1738..7b443554440b 100644
--- a/mlir/test/python/lib/PythonTestCAPI.cpp
+++ b/mlir/test/python/lib/PythonTestCAPI.cpp
@@ -16,7 +16,7 @@ MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(PythonTest, python_test,
python_test::PythonTestDialect)
bool mlirAttributeIsAPythonTestTestAttribute(MlirAttribute attr) {
- return unwrap(attr).isa<python_test::TestAttrAttr>();
+ return llvm::isa<python_test::TestAttrAttr>(unwrap(attr));
}
MlirAttribute mlirPythonTestTestAttributeGet(MlirContext context) {
@@ -24,7 +24,7 @@ MlirAttribute mlirPythonTestTestAttributeGet(MlirContext context) {
}
bool mlirTypeIsAPythonTestTestType(MlirType type) {
- return unwrap(type).isa<python_test::TestTypeType>();
+ return llvm::isa<python_test::TestTypeType>(unwrap(type));
}
MlirType mlirPythonTestTestTypeGet(MlirContext context) {