summaryrefslogtreecommitdiff
path: root/mlir/lib
diff options
context:
space:
mode:
authorMathieu Fehr <mathieu.fehr@gmail.com>2023-03-08 23:16:02 +0100
committerMathieu Fehr <mathieu.fehr@gmail.com>2023-05-17 13:34:00 +0100
commitc8a581c331f27a1ece8e42206831e56b7a222d26 (patch)
tree2f243907cb494873ba73821f612a1823fa90b0f9 /mlir/lib
parent4f30a63ca2a6cbc16beaa49df16373d020118e92 (diff)
downloadllvm-c8a581c331f27a1ece8e42206831e56b7a222d26.tar.gz
[mlir][irdl] Add verification of IRDL ops
This patch adds verification on registered IRDL operations, types, and attributes. This is done through an interface implemented by operations from the `irdl` dialect, which translate the operations into `Constraint`. This interface is then use in the `registerDialect` function to generate verifiers for the entire operation/type/attribute. Depends on D145733 Reviewed By: Mogball Differential Revision: https://reviews.llvm.org/D145734
Diffstat (limited to 'mlir/lib')
-rw-r--r--mlir/lib/Dialect/IRDL/CMakeLists.txt1
-rw-r--r--mlir/lib/Dialect/IRDL/IR/IRDL.cpp2
-rw-r--r--mlir/lib/Dialect/IRDL/IR/IRDLOps.cpp61
-rw-r--r--mlir/lib/Dialect/IRDL/IRDLLoading.cpp225
4 files changed, 285 insertions, 4 deletions
diff --git a/mlir/lib/Dialect/IRDL/CMakeLists.txt b/mlir/lib/Dialect/IRDL/CMakeLists.txt
index 7af0e4229357..d25760e5d29b 100644
--- a/mlir/lib/Dialect/IRDL/CMakeLists.txt
+++ b/mlir/lib/Dialect/IRDL/CMakeLists.txt
@@ -1,5 +1,6 @@
add_mlir_dialect_library(MLIRIRDL
IR/IRDL.cpp
+ IR/IRDLOps.cpp
IRDLLoading.cpp
IRDLVerifiers.cpp
diff --git a/mlir/lib/Dialect/IRDL/IR/IRDL.cpp b/mlir/lib/Dialect/IRDL/IR/IRDL.cpp
index e2649f22094c..01e58ccf0a5b 100644
--- a/mlir/lib/Dialect/IRDL/IR/IRDL.cpp
+++ b/mlir/lib/Dialect/IRDL/IR/IRDL.cpp
@@ -71,6 +71,8 @@ LogicalResult DialectOp::verify() {
return success();
}
+#include "mlir/Dialect/IRDL/IR/IRDLInterfaces.cpp.inc"
+
#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/IRDL/IR/IRDLTypesGen.cpp.inc"
diff --git a/mlir/lib/Dialect/IRDL/IR/IRDLOps.cpp b/mlir/lib/Dialect/IRDL/IR/IRDLOps.cpp
new file mode 100644
index 000000000000..a9956cc630cc
--- /dev/null
+++ b/mlir/lib/Dialect/IRDL/IR/IRDLOps.cpp
@@ -0,0 +1,61 @@
+//===- IRDLOps.cpp - IRDL dialect -------------------------------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/IRDL/IR/IRDL.h"
+
+using namespace mlir;
+using namespace mlir::irdl;
+
+std::unique_ptr<Constraint> Is::getVerifier(
+ SmallVector<Value> const &valueToConstr,
+ DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>> &types,
+ DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>> &attrs) {
+ return std::make_unique<IsConstraint>(getExpectedAttr());
+}
+
+std::unique_ptr<Constraint> Parametric::getVerifier(
+ SmallVector<Value> const &valueToConstr,
+ DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>> &types,
+ DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>> &attrs) {
+ SmallVector<unsigned> constraints;
+ for (Value arg : getArgs()) {
+ for (auto [i, value] : enumerate(valueToConstr)) {
+ if (value == arg) {
+ constraints.push_back(i);
+ break;
+ }
+ }
+ }
+
+ // Symbol reference case for the base
+ SymbolRefAttr symRef = getBaseType();
+ Operation *defOp =
+ SymbolTable::lookupNearestSymbolFrom(getOperation(), symRef);
+ if (!defOp) {
+ emitError() << symRef << " does not refer to any existing symbol";
+ return nullptr;
+ }
+
+ if (auto typeOp = dyn_cast<TypeOp>(defOp))
+ return std::make_unique<DynParametricTypeConstraint>(types[typeOp].get(),
+ constraints);
+
+ if (auto attrOp = dyn_cast<AttributeOp>(defOp))
+ return std::make_unique<DynParametricAttrConstraint>(attrs[attrOp].get(),
+ constraints);
+
+ llvm_unreachable("verifier should ensure that the referenced operation is "
+ "either a type or an attribute definition");
+}
+
+std::unique_ptr<Constraint> Any::getVerifier(
+ SmallVector<Value> const &valueToConstr,
+ DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>> &types,
+ DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>> &attrs) {
+ return std::make_unique<AnyAttributeConstraint>();
+}
diff --git a/mlir/lib/Dialect/IRDL/IRDLLoading.cpp b/mlir/lib/Dialect/IRDL/IRDLLoading.cpp
index fb00085a7ee0..f65d0eceb03a 100644
--- a/mlir/lib/Dialect/IRDL/IRDLLoading.cpp
+++ b/mlir/lib/Dialect/IRDL/IRDLLoading.cpp
@@ -12,6 +12,7 @@
#include "mlir/Dialect/IRDL/IRDLLoading.h"
#include "mlir/Dialect/IRDL/IR/IRDL.h"
+#include "mlir/Dialect/IRDL/IR/IRDLInterfaces.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/ExtensibleDialect.h"
#include "mlir/Support/LogicalResult.h"
@@ -22,9 +23,130 @@
using namespace mlir;
using namespace mlir::irdl;
+/// Verify that the given list of parameters satisfy the given constraints.
+/// This encodes the logic of the verification method for attributes and types
+/// defined with IRDL.
+static LogicalResult
+irdlAttrOrTypeVerifier(function_ref<InFlightDiagnostic()> emitError,
+ ArrayRef<Attribute> params,
+ ArrayRef<std::unique_ptr<Constraint>> constraints,
+ ArrayRef<size_t> paramConstraints) {
+ if (params.size() != paramConstraints.size()) {
+ emitError() << "expected " << paramConstraints.size()
+ << " type arguments, but had " << params.size();
+ return failure();
+ }
+
+ ConstraintVerifier verifier(constraints);
+
+ // Check that each parameter satisfies its constraint.
+ for (auto [i, param] : enumerate(params))
+ if (failed(verifier.verify(emitError, param, paramConstraints[i])))
+ return failure();
+
+ return success();
+}
+
+/// Verify that the given operation satisfies the given constraints.
+/// This encodes the logic of the verification method for operations defined
+/// with IRDL.
+static LogicalResult
+irdlOpVerifier(Operation *op, ArrayRef<std::unique_ptr<Constraint>> constraints,
+ ArrayRef<size_t> operandConstrs,
+ ArrayRef<size_t> resultConstrs) {
+ /// Check that we have the right number of operands.
+ unsigned numOperands = op->getNumOperands();
+ size_t numExpectedOperands = operandConstrs.size();
+ if (numOperands != numExpectedOperands)
+ return op->emitOpError() << numExpectedOperands
+ << " operands expected, but got " << numOperands;
+
+ /// Check that we have the right number of results.
+ unsigned numResults = op->getNumResults();
+ size_t numExpectedResults = resultConstrs.size();
+ if (numResults != numExpectedResults)
+ return op->emitOpError()
+ << numExpectedResults << " results expected, but got " << numResults;
+
+ auto emitError = [op]() { return op->emitError(); };
+
+ ConstraintVerifier verifier(constraints);
+
+ /// Check that all operands satisfy the constraints.
+ for (auto [i, operandType] : enumerate(op->getOperandTypes()))
+ if (failed(verifier.verify({emitError}, TypeAttr::get(operandType),
+ operandConstrs[i])))
+ return failure();
+
+ /// Check that all results satisfy the constraints.
+ for (auto [i, resultType] : enumerate(op->getResultTypes()))
+ if (failed(verifier.verify({emitError}, TypeAttr::get(resultType),
+ resultConstrs[i])))
+ return failure();
+
+ return success();
+}
+
/// Define and load an operation represented by a `irdl.operation`
/// operation.
-static WalkResult loadOperation(OperationOp op, ExtensibleDialect *dialect) {
+static WalkResult loadOperation(
+ OperationOp op, ExtensibleDialect *dialect,
+ DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>> &types,
+ DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>> &attrs) {
+ // Resolve SSA values to verifier constraint slots
+ SmallVector<Value> constrToValue;
+ for (Operation &op : op->getRegion(0).getOps()) {
+ if (isa<VerifyConstraintInterface>(op)) {
+ if (op.getNumResults() != 1)
+ return op.emitError()
+ << "IRDL constraint operations must have exactly one result";
+ constrToValue.push_back(op.getResult(0));
+ }
+ }
+
+ // Build the verifiers for each constraint slot
+ SmallVector<std::unique_ptr<Constraint>> constraints;
+ for (Value v : constrToValue) {
+ VerifyConstraintInterface op =
+ cast<VerifyConstraintInterface>(v.getDefiningOp());
+ std::unique_ptr<Constraint> verifier =
+ op.getVerifier(constrToValue, types, attrs);
+ if (!verifier)
+ return WalkResult::interrupt();
+ constraints.push_back(std::move(verifier));
+ }
+
+ SmallVector<size_t> operandConstraints;
+ SmallVector<size_t> resultConstraints;
+
+ // Gather which constraint slots correspond to operand constraints
+ auto operandsOp = op.getOp<OperandsOp>();
+ if (operandsOp.has_value()) {
+ operandConstraints.reserve(operandsOp->getArgs().size());
+ for (Value operand : operandsOp->getArgs()) {
+ for (auto [i, constr] : enumerate(constrToValue)) {
+ if (constr == operand) {
+ operandConstraints.push_back(i);
+ break;
+ }
+ }
+ }
+ }
+
+ // Gather which constraint slots correspond to result constraints
+ auto resultsOp = op.getOp<ResultsOp>();
+ if (resultsOp.has_value()) {
+ resultConstraints.reserve(resultsOp->getArgs().size());
+ for (Value result : resultsOp->getArgs()) {
+ for (auto [i, constr] : enumerate(constrToValue)) {
+ if (constr == result) {
+ resultConstraints.push_back(i);
+ break;
+ }
+ }
+ }
+ }
+
// IRDL does not support defining custom parsers or printers.
auto parser = [](OpAsmParser &parser, OperationState &result) {
return failure();
@@ -33,7 +155,13 @@ static WalkResult loadOperation(OperationOp op, ExtensibleDialect *dialect) {
printer.printGenericOp(op);
};
- auto verifier = [](Operation *op) { return success(); };
+ auto verifier =
+ [constraints{std::move(constraints)},
+ operandConstraints{std::move(operandConstraints)},
+ resultConstraints{std::move(resultConstraints)}](Operation *op) {
+ return irdlOpVerifier(op, constraints, operandConstraints,
+ resultConstraints);
+ };
// IRDL does not support defining regions.
auto regionVerifier = [](Operation *op) { return success(); };
@@ -46,6 +174,71 @@ static WalkResult loadOperation(OperationOp op, ExtensibleDialect *dialect) {
return WalkResult::advance();
}
+/// Get the verifier of a type or attribute definition.
+/// Return nullptr if the definition is invalid.
+static DynamicAttrDefinition::VerifierFn getAttrOrTypeVerifier(
+ Operation *attrOrTypeDef, ExtensibleDialect *dialect,
+ DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>> &types,
+ DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>> &attrs) {
+ assert((isa<AttributeOp>(attrOrTypeDef) || isa<TypeOp>(attrOrTypeDef)) &&
+ "Expected an attribute or type definition");
+
+ // Resolve SSA values to verifier constraint slots
+ SmallVector<Value> constrToValue;
+ for (Operation &op : attrOrTypeDef->getRegion(0).getOps()) {
+ if (isa<VerifyConstraintInterface>(op)) {
+ assert(op.getNumResults() == 1 &&
+ "IRDL constraint operations must have exactly one result");
+ constrToValue.push_back(op.getResult(0));
+ }
+ }
+
+ // Build the verifiers for each constraint slot
+ SmallVector<std::unique_ptr<Constraint>> constraints;
+ for (Value v : constrToValue) {
+ VerifyConstraintInterface op =
+ cast<VerifyConstraintInterface>(v.getDefiningOp());
+ std::unique_ptr<Constraint> verifier =
+ op.getVerifier(constrToValue, types, attrs);
+ if (!verifier)
+ return {};
+ constraints.push_back(std::move(verifier));
+ }
+
+ // Get the parameter definitions.
+ std::optional<ParametersOp> params;
+ if (auto attr = dyn_cast<AttributeOp>(attrOrTypeDef))
+ params = attr.getOp<ParametersOp>();
+ else if (auto type = dyn_cast<TypeOp>(attrOrTypeDef))
+ params = type.getOp<ParametersOp>();
+
+ // Gather which constraint slots correspond to parameter constraints
+ SmallVector<size_t> paramConstraints;
+ if (params.has_value()) {
+ paramConstraints.reserve(params->getArgs().size());
+ for (Value param : params->getArgs()) {
+ for (auto [i, constr] : enumerate(constrToValue)) {
+ if (constr == param) {
+ paramConstraints.push_back(i);
+ break;
+ }
+ }
+ }
+ }
+
+ auto verifier = [paramConstraints{std::move(paramConstraints)},
+ constraints{std::move(constraints)}](
+ function_ref<InFlightDiagnostic()> emitError,
+ ArrayRef<Attribute> params) {
+ return irdlAttrOrTypeVerifier(emitError, params, constraints,
+ paramConstraints);
+ };
+
+ // While the `std::move` is not required, not adding it triggers a bug in
+ // clang-10.
+ return std::move(verifier);
+}
+
/// Load all dialects in the given module, without loading any operation, type
/// or attribute definitions.
static DenseMap<DialectOp, ExtensibleDialect *> loadEmptyDialects(ModuleOp op) {
@@ -108,9 +301,33 @@ LogicalResult mlir::irdl::loadDialects(ModuleOp op) {
DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>> attrs =
preallocateAttrDefs(op, dialects);
+ // Set the verifier for types.
+ WalkResult res = op.walk([&](TypeOp typeOp) {
+ DynamicAttrDefinition::VerifierFn verifier = getAttrOrTypeVerifier(
+ typeOp, dialects[typeOp.getParentOp()], types, attrs);
+ if (!verifier)
+ return WalkResult::interrupt();
+ types[typeOp]->setVerifyFn(std::move(verifier));
+ return WalkResult::advance();
+ });
+ if (res.wasInterrupted())
+ return failure();
+
+ // Set the verifier for attributes.
+ res = op.walk([&](AttributeOp attrOp) {
+ DynamicAttrDefinition::VerifierFn verifier = getAttrOrTypeVerifier(
+ attrOp, dialects[attrOp.getParentOp()], types, attrs);
+ if (!verifier)
+ return WalkResult::interrupt();
+ attrs[attrOp]->setVerifyFn(std::move(verifier));
+ return WalkResult::advance();
+ });
+ if (res.wasInterrupted())
+ return failure();
+
// Define and load all operations.
- WalkResult res = op.walk([&](OperationOp opOp) {
- return loadOperation(opOp, dialects[opOp.getParentOp()]);
+ res = op.walk([&](OperationOp opOp) {
+ return loadOperation(opOp, dialects[opOp.getParentOp()], types, attrs);
});
if (res.wasInterrupted())
return failure();