diff options
Diffstat (limited to 'mlir/lib/Dialect')
-rw-r--r-- | mlir/lib/Dialect/IRDL/IR/IRDLOps.cpp | 61 | ||||
-rw-r--r-- | mlir/lib/Dialect/IRDL/IRDLLoading.cpp | 117 |
2 files changed, 167 insertions, 11 deletions
diff --git a/mlir/lib/Dialect/IRDL/IR/IRDLOps.cpp b/mlir/lib/Dialect/IRDL/IR/IRDLOps.cpp index a9956cc630cc..c0e839720200 100644 --- a/mlir/lib/Dialect/IRDL/IR/IRDLOps.cpp +++ b/mlir/lib/Dialect/IRDL/IR/IRDLOps.cpp @@ -12,16 +12,18 @@ using namespace mlir; using namespace mlir::irdl; std::unique_ptr<Constraint> Is::getVerifier( - SmallVector<Value> const &valueToConstr, - DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>> &types, - DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>> &attrs) { + ArrayRef<Value> valueToConstr, + DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>> const &types, + DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>> const + &attrs) { return std::make_unique<IsConstraint>(getExpectedAttr()); } std::unique_ptr<Constraint> Parametric::getVerifier( - SmallVector<Value> const &valueToConstr, - DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>> &types, - DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>> &attrs) { + ArrayRef<Value> valueToConstr, + DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>> const &types, + DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>> const + &attrs) { SmallVector<unsigned> constraints; for (Value arg : getArgs()) { for (auto [i, value] : enumerate(valueToConstr)) { @@ -42,20 +44,57 @@ std::unique_ptr<Constraint> Parametric::getVerifier( } if (auto typeOp = dyn_cast<TypeOp>(defOp)) - return std::make_unique<DynParametricTypeConstraint>(types[typeOp].get(), + return std::make_unique<DynParametricTypeConstraint>(types.at(typeOp).get(), constraints); if (auto attrOp = dyn_cast<AttributeOp>(defOp)) - return std::make_unique<DynParametricAttrConstraint>(attrs[attrOp].get(), + return std::make_unique<DynParametricAttrConstraint>(attrs.at(attrOp).get(), constraints); llvm_unreachable("verifier should ensure that the referenced operation is " "either a type or an attribute definition"); } +std::unique_ptr<Constraint> AnyOf::getVerifier( + ArrayRef<Value> valueToConstr, + DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>> const &types, + DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>> const + &attrs) { + SmallVector<unsigned> constraints; + for (Value arg : getArgs()) { + for (auto [i, value] : enumerate(valueToConstr)) { + if (value == arg) { + constraints.push_back(i); + break; + } + } + } + + return std::make_unique<AnyOfConstraint>(constraints); +} + +std::unique_ptr<Constraint> AllOf::getVerifier( + ArrayRef<Value> valueToConstr, + DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>> const &types, + DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>> const + &attrs) { + SmallVector<unsigned> constraints; + for (Value arg : getArgs()) { + for (auto [i, value] : enumerate(valueToConstr)) { + if (value == arg) { + constraints.push_back(i); + break; + } + } + } + + return std::make_unique<AllOfConstraint>(constraints); +} + std::unique_ptr<Constraint> Any::getVerifier( - SmallVector<Value> const &valueToConstr, - DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>> &types, - DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>> &attrs) { + ArrayRef<Value> valueToConstr, + DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>> const &types, + DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>> const + &attrs) { return std::make_unique<AnyAttributeConstraint>(); } diff --git a/mlir/lib/Dialect/IRDL/IRDLLoading.cpp b/mlir/lib/Dialect/IRDL/IRDLLoading.cpp index f65d0eceb03a..07f0e4e5e443 100644 --- a/mlir/lib/Dialect/IRDL/IRDLLoading.cpp +++ b/mlir/lib/Dialect/IRDL/IRDLLoading.cpp @@ -239,6 +239,116 @@ static DynamicAttrDefinition::VerifierFn getAttrOrTypeVerifier( return std::move(verifier); } +/// Get the possible bases of a constraint. Return `true` if all bases can +/// potentially be matched. +/// A base is a type or an attribute definition. For instance, the base of +/// `irdl.parametric "!builtin.complex"(...)` is `builtin.complex`. +/// This function returns the following information through arguments: +/// - `paramIds`: the set of type or attribute IDs that are used as bases. +/// - `paramIrdlOps`: the set of IRDL operations that are used as bases. +/// - `isIds`: the set of type or attribute IDs that are used in `irdl.is` +/// constraints. +static bool getBases(Operation *op, SmallPtrSet<TypeID, 4> ¶mIds, + SmallPtrSet<Operation *, 4> ¶mIrdlOps, + SmallPtrSet<TypeID, 4> &isIds) { + // For `irdl.any_of`, we get the bases from all its arguments. + if (auto anyOf = dyn_cast<AnyOf>(op)) { + bool has_any = false; + for (Value arg : anyOf.getArgs()) + has_any &= getBases(arg.getDefiningOp(), paramIds, paramIrdlOps, isIds); + return has_any; + } + + // For `irdl.all_of`, we get the bases from the first argument. + // This is restrictive, but we can relax it later if needed. + if (auto allOf = dyn_cast<AllOf>(op)) + return getBases(allOf.getArgs()[0].getDefiningOp(), paramIds, paramIrdlOps, + isIds); + + // For `irdl.parametric`, we get directly the base from the operation. + if (auto params = dyn_cast<Parametric>(op)) { + SymbolRefAttr symRef = params.getBaseType(); + Operation *defOp = SymbolTable::lookupNearestSymbolFrom(op, symRef); + assert(defOp && "symbol reference should refer to an existing operation"); + paramIrdlOps.insert(defOp); + return false; + } + + // For `irdl.is`, we get the base TypeID directly. + if (auto is = dyn_cast<Is>(op)) { + Attribute expected = is.getExpected(); + isIds.insert(expected.getTypeID()); + return false; + } + + // For `irdl.any`, we return `false` since we can match any type or attribute + // base. + if (auto isA = dyn_cast<Any>(op)) + return true; + + llvm_unreachable("unknown IRDL constraint"); +} + +/// Check that an any_of is in the subset IRDL can handle. +/// IRDL uses a greedy algorithm to match constraints. This means that if we +/// encounter an `any_of` with multiple constraints, we will match the first +/// constraint that is satisfied. Thus, the order of constraints matter in +/// `any_of` with our current algorithm. +/// In order to make the order of constraints irrelevant, we require that +/// all `any_of` constraint parameters are disjoint. For this, we check that +/// the base parameters are all disjoints between `parametric` operations, and +/// that they are disjoint between `parametric` and `is` operations. +/// This restriction will be relaxed in the future, when we will change our +/// algorithm to be non-greedy. +static LogicalResult checkCorrectAnyOf(AnyOf anyOf) { + SmallPtrSet<TypeID, 4> paramIds; + SmallPtrSet<Operation *, 4> paramIrdlOps; + SmallPtrSet<TypeID, 4> isIds; + + for (Value arg : anyOf.getArgs()) { + Operation *argOp = arg.getDefiningOp(); + SmallPtrSet<TypeID, 4> argParamIds; + SmallPtrSet<Operation *, 4> argParamIrdlOps; + SmallPtrSet<TypeID, 4> argIsIds; + + // Get the bases of this argument. If it can match any type or attribute, + // then our `any_of` should not be allowed. + if (getBases(argOp, argParamIds, argParamIrdlOps, argIsIds)) + return failure(); + + // We check that the base parameters are all disjoints between `parametric` + // operations, and that they are disjoint between `parametric` and `is` + // operations. + for (TypeID id : argParamIds) { + if (isIds.count(id)) + return failure(); + bool inserted = paramIds.insert(id).second; + if (!inserted) + return failure(); + } + + // We check that the base parameters are all disjoints with `irdl.is` + // operations. + for (TypeID id : isIds) { + if (paramIds.count(id)) + return failure(); + isIds.insert(id); + } + + // We check that all `parametric` operations are disjoint. We do not + // need to check that they are disjoint with `is` operations, since + // `is` operations cannot refer to attributes defined with `irdl.parametric` + // operations. + for (Operation *op : argParamIrdlOps) { + bool inserted = paramIrdlOps.insert(op).second; + if (!inserted) + return failure(); + } + } + + return success(); +} + /// Load all dialects in the given module, without loading any operation, type /// or attribute definitions. static DenseMap<DialectOp, ExtensibleDialect *> loadEmptyDialects(ModuleOp op) { @@ -292,6 +402,13 @@ preallocateAttrDefs(ModuleOp op, } LogicalResult mlir::irdl::loadDialects(ModuleOp op) { + // First, check that all any_of constraints are in a correct form. + // This is to ensure we can do the verification correctly. + WalkResult anyOfCorrects = + op.walk([](AnyOf anyOf) { return (WalkResult)checkCorrectAnyOf(anyOf); }); + if (anyOfCorrects.wasInterrupted()) + return op.emitError("any_of constraints are not in the correct form"); + // Preallocate all dialects, and type and attribute definitions. // In particular, this allocates TypeIDs so type and attributes can have // verifiers that refer to each other. |