summaryrefslogtreecommitdiff
path: root/mlir/include
diff options
context:
space:
mode:
authorTres Popp <tpopp@google.com>2023-05-11 11:10:46 +0200
committerTres Popp <tpopp@google.com>2023-05-12 11:21:30 +0200
commitc1fa60b4cde512964544ab66404dea79dbc5dcb4 (patch)
tree729fa03855abbe296208554fa4a7fd2dc742ab6b /mlir/include
parent5550c821897ab77e664977121a0e90ad5be1ff59 (diff)
downloadllvm-c1fa60b4cde512964544ab66404dea79dbc5dcb4.tar.gz
[mlir] Update method cast calls to function calls
The MLIR classes Type/Attribute/Operation/Op/Value support cast/dyn_cast/isa/dyn_cast_or_null functionality through llvm's doCast functionality in addition to defining methods with the same name. This change begins the migration of uses of the method to the corresponding function call as has been decided as more consistent. Note that there still exist classes that only define methods directly, such as AffineExpr, and this does not include work currently to support a functional cast/isa call. Context: * https://mlir.llvm.org/deprecation/ at "Use the free function variants for dyn_cast/cast/isa/…" * Original discussion at https://discourse.llvm.org/t/preferred-casting-style-going-forward/68443 Implementation: This follows a previous patch that updated calls `op.cast<T>()-> cast<T>(op)`. However some cases could not handle an unprefixed `cast` call due to occurrences of variables named cast, or occurring inside of class definitions which would resolve to the method. All C++ files that did not work automatically with `cast<T>()` are updated here to `llvm::cast` and similar with the intention that they can be easily updated after the methods are removed through a find-replace. See https://github.com/llvm/llvm-project/compare/main...tpopp:llvm-project:tidy-cast-check for the clang-tidy check that is used and then update printed occurrences of the function to include `llvm::` before. One can then run the following: ``` ninja -C $BUILD_DIR clang-tidy run-clang-tidy -clang-tidy-binary=$BUILD_DIR/bin/clang-tidy -checks='-*,misc-cast-functions'\ -export-fixes /tmp/cast/casts.yaml mlir/*\ -header-filter=mlir/ -fix rm -rf $BUILD_DIR/tools/mlir/**/*.inc ``` Differential Revision: https://reviews.llvm.org/D150348
Diffstat (limited to 'mlir/include')
-rw-r--r--mlir/include/mlir/IR/Builders.h2
-rw-r--r--mlir/include/mlir/IR/BuiltinAttributes.h22
-rw-r--r--mlir/include/mlir/IR/BuiltinTypes.h17
-rw-r--r--mlir/include/mlir/IR/FunctionInterfaces.h4
-rw-r--r--mlir/include/mlir/IR/Location.h4
-rw-r--r--mlir/include/mlir/IR/Matchers.h16
-rw-r--r--mlir/include/mlir/IR/OpImplementation.h16
-rw-r--r--mlir/include/mlir/IR/Operation.h4
-rw-r--r--mlir/include/mlir/IR/Value.h2
9 files changed, 43 insertions, 44 deletions
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index 1b0f4bfb3f62..4dbeb418099d 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -403,7 +403,7 @@ public:
if (Operation *op = val.getDefiningOp()) {
setInsertionPointAfter(op);
} else {
- auto blockArg = val.cast<BlockArgument>();
+ auto blockArg = llvm::cast<BlockArgument>(val);
setInsertionPointToStart(blockArg.getOwner());
}
}
diff --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h
index 496c197e4715..7c4136021cb5 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.h
+++ b/mlir/include/mlir/IR/BuiltinAttributes.h
@@ -389,7 +389,7 @@ public:
!std::is_same<Attribute, T>::value,
T>
getSplatValue() const {
- return getSplatValue<Attribute>().template cast<T>();
+ return llvm::cast<T>(getSplatValue<Attribute>());
}
/// Try to get an iterator of the given type to the start of the held element
@@ -510,7 +510,7 @@ public:
T>::mapped_iterator_base;
/// Map the element to the iterator result type.
- T mapElement(Attribute attr) const { return attr.cast<T>(); }
+ T mapElement(Attribute attr) const { return llvm::cast<T>(attr); }
};
template <typename T, typename = DerivedAttrValueTemplateCheckT<T>>
FailureOr<iterator_range_impl<DerivedAttributeElementIterator<T>>>
@@ -684,7 +684,7 @@ public:
/// Method for support type inquiry through isa, cast and dyn_cast.
static bool classof(Attribute attr) {
- auto denseAttr = attr.dyn_cast<DenseElementsAttr>();
+ auto denseAttr = llvm::dyn_cast<DenseElementsAttr>(attr);
return denseAttr && denseAttr.isSplat();
}
};
@@ -887,7 +887,7 @@ public:
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(Attribute attr) {
- SymbolRefAttr refAttr = attr.dyn_cast<SymbolRefAttr>();
+ SymbolRefAttr refAttr = llvm::dyn_cast<SymbolRefAttr>(attr);
return refAttr && refAttr.getNestedReferences().empty();
}
@@ -912,14 +912,13 @@ public:
/// simply wraps the DenseElementsAttr::get calls.
template <typename Arg>
static DenseFPElementsAttr get(const ShapedType &type, Arg &&arg) {
- return DenseElementsAttr::get(type, llvm::ArrayRef(arg))
- .template cast<DenseFPElementsAttr>();
+ return llvm::cast<DenseFPElementsAttr>(
+ DenseElementsAttr::get(type, llvm::ArrayRef(arg)));
}
template <typename T>
static DenseFPElementsAttr get(const ShapedType &type,
const std::initializer_list<T> &list) {
- return DenseElementsAttr::get(type, list)
- .template cast<DenseFPElementsAttr>();
+ return llvm::cast<DenseFPElementsAttr>(DenseElementsAttr::get(type, list));
}
/// Generates a new DenseElementsAttr by mapping each value attribute, and
@@ -954,14 +953,13 @@ public:
/// simply wraps the DenseElementsAttr::get calls.
template <typename Arg>
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg) {
- return DenseElementsAttr::get(type, llvm::ArrayRef(arg))
- .template cast<DenseIntElementsAttr>();
+ return llvm::cast<DenseIntElementsAttr>(
+ DenseElementsAttr::get(type, llvm::ArrayRef(arg)));
}
template <typename T>
static DenseIntElementsAttr get(const ShapedType &type,
const std::initializer_list<T> &list) {
- return DenseElementsAttr::get(type, list)
- .template cast<DenseIntElementsAttr>();
+ return llvm::cast<DenseIntElementsAttr>(DenseElementsAttr::get(type, list));
}
/// Generates a new DenseElementsAttr by mapping each value attribute, and
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index f26465ef1d66..4fc82dd7a8e9 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -367,20 +367,21 @@ SliceVerificationResult isRankReducedType(ShapedType originalType,
//===----------------------------------------------------------------------===//
inline bool BaseMemRefType::classof(Type type) {
- return type.isa<MemRefType, UnrankedMemRefType>();
+ return llvm::isa<MemRefType, UnrankedMemRefType>(type);
}
inline bool BaseMemRefType::isValidElementType(Type type) {
return type.isIntOrIndexOrFloat() ||
- type.isa<ComplexType, MemRefType, VectorType, UnrankedMemRefType>() ||
- type.isa<MemRefElementTypeInterface>();
+ llvm::isa<ComplexType, MemRefType, VectorType, UnrankedMemRefType>(
+ type) ||
+ llvm::isa<MemRefElementTypeInterface>(type);
}
inline bool FloatType::classof(Type type) {
- return type
- .isa<Float8E5M2Type, Float8E4M3FNType, Float8E5M2FNUZType,
- Float8E4M3FNUZType, Float8E4M3B11FNUZType, BFloat16Type, Float16Type,
- Float32Type, Float64Type, Float80Type, Float128Type>();
+ return llvm::isa<Float8E5M2Type, Float8E4M3FNType, Float8E5M2FNUZType,
+ Float8E4M3FNUZType, Float8E4M3B11FNUZType, BFloat16Type,
+ Float16Type, Float32Type, Float64Type, Float80Type,
+ Float128Type>(type);
}
inline FloatType FloatType::getFloat8E5M2(MLIRContext *ctx) {
@@ -428,7 +429,7 @@ inline FloatType FloatType::getF128(MLIRContext *ctx) {
}
inline bool TensorType::classof(Type type) {
- return type.isa<RankedTensorType, UnrankedTensorType>();
+ return llvm::isa<RankedTensorType, UnrankedTensorType>(type);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/FunctionInterfaces.h b/mlir/include/mlir/IR/FunctionInterfaces.h
index 3beb3db4e566..e813cb8f0390 100644
--- a/mlir/include/mlir/IR/FunctionInterfaces.h
+++ b/mlir/include/mlir/IR/FunctionInterfaces.h
@@ -178,7 +178,7 @@ LogicalResult verifyTrait(ConcreteOp op) {
}
for (unsigned i = 0; i != numArgs; ++i) {
DictionaryAttr argAttrs =
- allArgAttrs[i].dyn_cast_or_null<DictionaryAttr>();
+ llvm::dyn_cast_or_null<DictionaryAttr>(allArgAttrs[i]);
if (!argAttrs) {
return op.emitOpError() << "expects argument attribute dictionary "
"to be a DictionaryAttr, but got `"
@@ -209,7 +209,7 @@ LogicalResult verifyTrait(ConcreteOp op) {
}
for (unsigned i = 0; i != numResults; ++i) {
DictionaryAttr resultAttrs =
- allResultAttrs[i].dyn_cast_or_null<DictionaryAttr>();
+ llvm::dyn_cast_or_null<DictionaryAttr>(allResultAttrs[i]);
if (!resultAttrs) {
return op.emitOpError() << "expects result attribute dictionary "
"to be a DictionaryAttr, but got `"
diff --git a/mlir/include/mlir/IR/Location.h b/mlir/include/mlir/IR/Location.h
index 63b12899e249..d4268e804f4f 100644
--- a/mlir/include/mlir/IR/Location.h
+++ b/mlir/include/mlir/IR/Location.h
@@ -148,12 +148,12 @@ public:
/// Return the metadata associated with this fused location.
MetadataT getMetadata() const {
- return FusedLoc::getMetadata().template cast<MetadataT>();
+ return llvm::cast<MetadataT>(FusedLoc::getMetadata());
}
/// Support llvm style casting.
static bool classof(Attribute attr) {
- auto fusedLoc = attr.dyn_cast<FusedLoc>();
+ auto fusedLoc = llvm::dyn_cast<FusedLoc>(attr);
return fusedLoc && fusedLoc.getMetadata().isa_and_nonnull<MetadataT>();
}
};
diff --git a/mlir/include/mlir/IR/Matchers.h b/mlir/include/mlir/IR/Matchers.h
index 4dbc623916ac..2361a541efc2 100644
--- a/mlir/include/mlir/IR/Matchers.h
+++ b/mlir/include/mlir/IR/Matchers.h
@@ -39,7 +39,7 @@ struct attr_value_binder {
attr_value_binder(ValueType *bv) : bind_value(bv) {}
bool match(const Attribute &attr) {
- if (auto intAttr = attr.dyn_cast<AttrClass>()) {
+ if (auto intAttr = llvm::dyn_cast<AttrClass>(attr)) {
*bind_value = intAttr.getValue();
return true;
}
@@ -90,7 +90,7 @@ struct constant_op_binder {
(void)result;
assert(succeeded(result) && "expected ConstantLike op to be foldable");
- if (auto attr = foldedOp.front().get<Attribute>().dyn_cast<AttrT>()) {
+ if (auto attr = llvm::dyn_cast<AttrT>(foldedOp.front().get<Attribute>())) {
if (bind_value)
*bind_value = attr;
return true;
@@ -136,10 +136,10 @@ struct constant_float_op_binder {
return false;
auto type = op->getResult(0).getType();
- if (type.isa<FloatType>())
+ if (llvm::isa<FloatType>(type))
return attr_value_binder<FloatAttr>(bind_value).match(attr);
- if (type.isa<VectorType, RankedTensorType>()) {
- if (auto splatAttr = attr.dyn_cast<SplatElementsAttr>()) {
+ if (llvm::isa<VectorType, RankedTensorType>(type)) {
+ if (auto splatAttr = llvm::dyn_cast<SplatElementsAttr>(attr)) {
return attr_value_binder<FloatAttr>(bind_value)
.match(splatAttr.getSplatValue<Attribute>());
}
@@ -173,10 +173,10 @@ struct constant_int_op_binder {
return false;
auto type = op->getResult(0).getType();
- if (type.isa<IntegerType, IndexType>())
+ if (llvm::isa<IntegerType, IndexType>(type))
return attr_value_binder<IntegerAttr>(bind_value).match(attr);
- if (type.isa<VectorType, RankedTensorType>()) {
- if (auto splatAttr = attr.dyn_cast<SplatElementsAttr>()) {
+ if (llvm::isa<VectorType, RankedTensorType>(type)) {
+ if (auto splatAttr = llvm::dyn_cast<SplatElementsAttr>(attr)) {
return attr_value_binder<IntegerAttr>(bind_value)
.match(splatAttr.getSplatValue<Attribute>());
}
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index f4045b236a44..4c36453f31b2 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -204,7 +204,7 @@ public:
auto &os = getStream() << " -> ";
bool wrapped = !llvm::hasSingleElement(types) ||
- (*types.begin()).template isa<FunctionType>();
+ llvm::isa<FunctionType>((*types.begin()));
if (wrapped)
os << '(';
llvm::interleaveComma(types, *this);
@@ -865,7 +865,7 @@ public:
return failure();
// Check for the right kind of attribute.
- if (!(result = attr.dyn_cast<AttrType>()))
+ if (!(result = llvm::dyn_cast<AttrType>(attr)))
return emitError(loc, "invalid kind of attribute specified");
return success();
@@ -899,7 +899,7 @@ public:
return failure();
// Check for the right kind of attribute.
- result = attr.dyn_cast<AttrType>();
+ result = llvm::dyn_cast<AttrType>(attr);
if (!result)
return emitError(loc, "invalid kind of attribute specified");
@@ -936,7 +936,7 @@ public:
return failure();
// Check for the right kind of attribute.
- result = attr.dyn_cast<AttrType>();
+ result = llvm::dyn_cast<AttrType>(attr);
if (!result)
return emitError(loc, "invalid kind of attribute specified");
@@ -970,7 +970,7 @@ public:
return failure();
// Check for the right kind of attribute.
- result = attr.dyn_cast<AttrType>();
+ result = llvm::dyn_cast<AttrType>(attr);
if (!result)
return emitError(loc, "invalid kind of attribute specified");
return success();
@@ -1126,7 +1126,7 @@ public:
return failure();
// Check for the right kind of type.
- result = type.dyn_cast<TypeT>();
+ result = llvm::dyn_cast<TypeT>(type);
if (!result)
return emitError(loc, "invalid kind of type specified");
@@ -1158,7 +1158,7 @@ public:
return failure();
// Check for the right kind of Type.
- result = type.dyn_cast<TypeT>();
+ result = llvm::dyn_cast<TypeT>(type);
if (!result)
return emitError(loc, "invalid kind of Type specified");
return success();
@@ -1198,7 +1198,7 @@ public:
return failure();
// Check for the right kind of type.
- result = type.dyn_cast<TypeType>();
+ result = llvm::dyn_cast<TypeType>(type);
if (!result)
return emitError(loc, "invalid kind of type specified");
diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h
index 8bd23a2a1fc5..ec6d4ca2d6e6 100644
--- a/mlir/include/mlir/IR/Operation.h
+++ b/mlir/include/mlir/IR/Operation.h
@@ -509,11 +509,11 @@ public:
template <typename AttrClass>
AttrClass getAttrOfType(StringAttr name) {
- return getAttr(name).dyn_cast_or_null<AttrClass>();
+ return llvm::dyn_cast_or_null<AttrClass>(getAttr(name));
}
template <typename AttrClass>
AttrClass getAttrOfType(StringRef name) {
- return getAttr(name).dyn_cast_or_null<AttrClass>();
+ return llvm::dyn_cast_or_null<AttrClass>(getAttr(name));
}
/// Return true if the operation has an attribute with the provided name,
diff --git a/mlir/include/mlir/IR/Value.h b/mlir/include/mlir/IR/Value.h
index 7a8aee29ca44..a280fbdf64bc 100644
--- a/mlir/include/mlir/IR/Value.h
+++ b/mlir/include/mlir/IR/Value.h
@@ -433,7 +433,7 @@ struct TypedValue : Value {
static bool classof(Value value) { return llvm::isa<Ty>(value.getType()); }
/// Return the known Type
- Ty getType() { return Value::getType().template cast<Ty>(); }
+ Ty getType() { return llvm::cast<Ty>(Value::getType()); }
void setType(Ty ty) { Value::setType(ty); }
};