summaryrefslogtreecommitdiff
path: root/mlir/lib/Dialect
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect')
-rw-r--r--mlir/lib/Dialect/IRDL/IR/IRDLOps.cpp61
-rw-r--r--mlir/lib/Dialect/IRDL/IRDLLoading.cpp117
2 files changed, 167 insertions, 11 deletions
diff --git a/mlir/lib/Dialect/IRDL/IR/IRDLOps.cpp b/mlir/lib/Dialect/IRDL/IR/IRDLOps.cpp
index a9956cc630cc..c0e839720200 100644
--- a/mlir/lib/Dialect/IRDL/IR/IRDLOps.cpp
+++ b/mlir/lib/Dialect/IRDL/IR/IRDLOps.cpp
@@ -12,16 +12,18 @@ using namespace mlir;
using namespace mlir::irdl;
std::unique_ptr<Constraint> Is::getVerifier(
- SmallVector<Value> const &valueToConstr,
- DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>> &types,
- DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>> &attrs) {
+ ArrayRef<Value> valueToConstr,
+ DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>> const &types,
+ DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>> const
+ &attrs) {
return std::make_unique<IsConstraint>(getExpectedAttr());
}
std::unique_ptr<Constraint> Parametric::getVerifier(
- SmallVector<Value> const &valueToConstr,
- DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>> &types,
- DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>> &attrs) {
+ ArrayRef<Value> valueToConstr,
+ DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>> const &types,
+ DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>> const
+ &attrs) {
SmallVector<unsigned> constraints;
for (Value arg : getArgs()) {
for (auto [i, value] : enumerate(valueToConstr)) {
@@ -42,20 +44,57 @@ std::unique_ptr<Constraint> Parametric::getVerifier(
}
if (auto typeOp = dyn_cast<TypeOp>(defOp))
- return std::make_unique<DynParametricTypeConstraint>(types[typeOp].get(),
+ return std::make_unique<DynParametricTypeConstraint>(types.at(typeOp).get(),
constraints);
if (auto attrOp = dyn_cast<AttributeOp>(defOp))
- return std::make_unique<DynParametricAttrConstraint>(attrs[attrOp].get(),
+ return std::make_unique<DynParametricAttrConstraint>(attrs.at(attrOp).get(),
constraints);
llvm_unreachable("verifier should ensure that the referenced operation is "
"either a type or an attribute definition");
}
+std::unique_ptr<Constraint> AnyOf::getVerifier(
+ ArrayRef<Value> valueToConstr,
+ DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>> const &types,
+ DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>> const
+ &attrs) {
+ SmallVector<unsigned> constraints;
+ for (Value arg : getArgs()) {
+ for (auto [i, value] : enumerate(valueToConstr)) {
+ if (value == arg) {
+ constraints.push_back(i);
+ break;
+ }
+ }
+ }
+
+ return std::make_unique<AnyOfConstraint>(constraints);
+}
+
+std::unique_ptr<Constraint> AllOf::getVerifier(
+ ArrayRef<Value> valueToConstr,
+ DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>> const &types,
+ DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>> const
+ &attrs) {
+ SmallVector<unsigned> constraints;
+ for (Value arg : getArgs()) {
+ for (auto [i, value] : enumerate(valueToConstr)) {
+ if (value == arg) {
+ constraints.push_back(i);
+ break;
+ }
+ }
+ }
+
+ return std::make_unique<AllOfConstraint>(constraints);
+}
+
std::unique_ptr<Constraint> Any::getVerifier(
- SmallVector<Value> const &valueToConstr,
- DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>> &types,
- DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>> &attrs) {
+ ArrayRef<Value> valueToConstr,
+ DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>> const &types,
+ DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>> const
+ &attrs) {
return std::make_unique<AnyAttributeConstraint>();
}
diff --git a/mlir/lib/Dialect/IRDL/IRDLLoading.cpp b/mlir/lib/Dialect/IRDL/IRDLLoading.cpp
index f65d0eceb03a..07f0e4e5e443 100644
--- a/mlir/lib/Dialect/IRDL/IRDLLoading.cpp
+++ b/mlir/lib/Dialect/IRDL/IRDLLoading.cpp
@@ -239,6 +239,116 @@ static DynamicAttrDefinition::VerifierFn getAttrOrTypeVerifier(
return std::move(verifier);
}
+/// Get the possible bases of a constraint. Return `true` if all bases can
+/// potentially be matched.
+/// A base is a type or an attribute definition. For instance, the base of
+/// `irdl.parametric "!builtin.complex"(...)` is `builtin.complex`.
+/// This function returns the following information through arguments:
+/// - `paramIds`: the set of type or attribute IDs that are used as bases.
+/// - `paramIrdlOps`: the set of IRDL operations that are used as bases.
+/// - `isIds`: the set of type or attribute IDs that are used in `irdl.is`
+/// constraints.
+static bool getBases(Operation *op, SmallPtrSet<TypeID, 4> &paramIds,
+ SmallPtrSet<Operation *, 4> &paramIrdlOps,
+ SmallPtrSet<TypeID, 4> &isIds) {
+ // For `irdl.any_of`, we get the bases from all its arguments.
+ if (auto anyOf = dyn_cast<AnyOf>(op)) {
+ bool has_any = false;
+ for (Value arg : anyOf.getArgs())
+ has_any &= getBases(arg.getDefiningOp(), paramIds, paramIrdlOps, isIds);
+ return has_any;
+ }
+
+ // For `irdl.all_of`, we get the bases from the first argument.
+ // This is restrictive, but we can relax it later if needed.
+ if (auto allOf = dyn_cast<AllOf>(op))
+ return getBases(allOf.getArgs()[0].getDefiningOp(), paramIds, paramIrdlOps,
+ isIds);
+
+ // For `irdl.parametric`, we get directly the base from the operation.
+ if (auto params = dyn_cast<Parametric>(op)) {
+ SymbolRefAttr symRef = params.getBaseType();
+ Operation *defOp = SymbolTable::lookupNearestSymbolFrom(op, symRef);
+ assert(defOp && "symbol reference should refer to an existing operation");
+ paramIrdlOps.insert(defOp);
+ return false;
+ }
+
+ // For `irdl.is`, we get the base TypeID directly.
+ if (auto is = dyn_cast<Is>(op)) {
+ Attribute expected = is.getExpected();
+ isIds.insert(expected.getTypeID());
+ return false;
+ }
+
+ // For `irdl.any`, we return `false` since we can match any type or attribute
+ // base.
+ if (auto isA = dyn_cast<Any>(op))
+ return true;
+
+ llvm_unreachable("unknown IRDL constraint");
+}
+
+/// Check that an any_of is in the subset IRDL can handle.
+/// IRDL uses a greedy algorithm to match constraints. This means that if we
+/// encounter an `any_of` with multiple constraints, we will match the first
+/// constraint that is satisfied. Thus, the order of constraints matter in
+/// `any_of` with our current algorithm.
+/// In order to make the order of constraints irrelevant, we require that
+/// all `any_of` constraint parameters are disjoint. For this, we check that
+/// the base parameters are all disjoints between `parametric` operations, and
+/// that they are disjoint between `parametric` and `is` operations.
+/// This restriction will be relaxed in the future, when we will change our
+/// algorithm to be non-greedy.
+static LogicalResult checkCorrectAnyOf(AnyOf anyOf) {
+ SmallPtrSet<TypeID, 4> paramIds;
+ SmallPtrSet<Operation *, 4> paramIrdlOps;
+ SmallPtrSet<TypeID, 4> isIds;
+
+ for (Value arg : anyOf.getArgs()) {
+ Operation *argOp = arg.getDefiningOp();
+ SmallPtrSet<TypeID, 4> argParamIds;
+ SmallPtrSet<Operation *, 4> argParamIrdlOps;
+ SmallPtrSet<TypeID, 4> argIsIds;
+
+ // Get the bases of this argument. If it can match any type or attribute,
+ // then our `any_of` should not be allowed.
+ if (getBases(argOp, argParamIds, argParamIrdlOps, argIsIds))
+ return failure();
+
+ // We check that the base parameters are all disjoints between `parametric`
+ // operations, and that they are disjoint between `parametric` and `is`
+ // operations.
+ for (TypeID id : argParamIds) {
+ if (isIds.count(id))
+ return failure();
+ bool inserted = paramIds.insert(id).second;
+ if (!inserted)
+ return failure();
+ }
+
+ // We check that the base parameters are all disjoints with `irdl.is`
+ // operations.
+ for (TypeID id : isIds) {
+ if (paramIds.count(id))
+ return failure();
+ isIds.insert(id);
+ }
+
+ // We check that all `parametric` operations are disjoint. We do not
+ // need to check that they are disjoint with `is` operations, since
+ // `is` operations cannot refer to attributes defined with `irdl.parametric`
+ // operations.
+ for (Operation *op : argParamIrdlOps) {
+ bool inserted = paramIrdlOps.insert(op).second;
+ if (!inserted)
+ return failure();
+ }
+ }
+
+ return success();
+}
+
/// Load all dialects in the given module, without loading any operation, type
/// or attribute definitions.
static DenseMap<DialectOp, ExtensibleDialect *> loadEmptyDialects(ModuleOp op) {
@@ -292,6 +402,13 @@ preallocateAttrDefs(ModuleOp op,
}
LogicalResult mlir::irdl::loadDialects(ModuleOp op) {
+ // First, check that all any_of constraints are in a correct form.
+ // This is to ensure we can do the verification correctly.
+ WalkResult anyOfCorrects =
+ op.walk([](AnyOf anyOf) { return (WalkResult)checkCorrectAnyOf(anyOf); });
+ if (anyOfCorrects.wasInterrupted())
+ return op.emitError("any_of constraints are not in the correct form");
+
// Preallocate all dialects, and type and attribute definitions.
// In particular, this allocates TypeIDs so type and attributes can have
// verifiers that refer to each other.