summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJacques Pienaar <jpienaar@google.com>2022-01-04 08:28:59 -0800
committerJacques Pienaar <jpienaar@google.com>2022-01-04 08:28:59 -0800
commit05594de2d77b6f4735b8d8d417039b60987b3a79 (patch)
treed75e64dedd7382e15a79f133dabc5e10721e48b7
parentda6b0d0b768e3ecb1af2fd9df2d98510f7aff45c (diff)
downloadllvm-05594de2d77b6f4735b8d8d417039b60987b3a79.tar.gz
[mlir][ods] Handle DeclareOpInterfaceMethods in formatgen
Previously it would not consider ops with DeclareOpInterfaceMethods<InferTypeOpInterface> as having the InferTypeOpInterface interfaces added. The OpInterface nested inside DeclareOpInterfaceMethods is not retained so that one could query it, so check for the the C++ class directly (a bit raw/low level - will be addressed in follow up). Differential Revision: https://reviews.llvm.org/D116572
-rw-r--r--mlir/test/lib/Dialect/Test/TestDialect.cpp9
-rw-r--r--mlir/test/lib/Dialect/Test/TestOps.td6
-rw-r--r--mlir/test/mlir-tblgen/op-format.mlir5
-rw-r--r--mlir/tools/mlir-tblgen/OpFormatGen.cpp13
4 files changed, 29 insertions, 4 deletions
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index aee0bdb13970..441817803ef0 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -264,6 +264,15 @@ Operation *TestDialect::materializeConstant(OpBuilder &builder, Attribute value,
return builder.create<TestOpConstant>(loc, type, value);
}
+::mlir::LogicalResult FormatInferType2Op::inferReturnTypes(
+ ::mlir::MLIRContext *context, ::llvm::Optional<::mlir::Location> location,
+ ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes,
+ ::mlir::RegionRange regions,
+ ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
+ inferredReturnTypes.assign({::mlir::IntegerType::get(context, 16)});
+ return ::mlir::success();
+}
+
void *TestDialect::getRegisteredInterfaceForOp(TypeID typeID,
OperationName opName) {
if (opName.getIdentifier() == "test.unregistered_side_effect_op" &&
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 39f0b0b7da56..6fad11b85ad8 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -2139,6 +2139,12 @@ def FormatInferTypeOp : TEST_Op<"format_infer_type", [InferTypeOpInterface]> {
}];
}
+// Check that formatget supports DeclareOpInterfaceMethods.
+def FormatInferType2Op : TEST_Op<"format_infer_type2", [DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+ let results = (outs AnyType);
+ let assemblyFormat = "attr-dict";
+}
+
// Base class for testing mixing allOperandTypes, allOperands, and
// inferResultTypes.
class FormatInferAllTypesBaseOp<string mnemonic, list<OpTrait> traits = []>
diff --git a/mlir/test/mlir-tblgen/op-format.mlir b/mlir/test/mlir-tblgen/op-format.mlir
index 152cd0a554f1..77afc41f6541 100644
--- a/mlir/test/mlir-tblgen/op-format.mlir
+++ b/mlir/test/mlir-tblgen/op-format.mlir
@@ -409,7 +409,10 @@ test.format_infer_variadic_type_from_non_variadic %i64, %i64 : i64
//===----------------------------------------------------------------------===//
// CHECK: test.format_infer_type
-%ignored_res7 = test.format_infer_type
+%ignored_res7a = test.format_infer_type
+
+// CHECK: test.format_infer_type2
+%ignored_res7b = test.format_infer_type2
// CHECK: test.format_infer_type_all_operands_and_types(%[[I64]], %[[I32]]) : i64, i32
%ignored_res8:2 = test.format_infer_type_all_operands_and_types(%i64, %i32) : i64, i32
diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
index 02d0e81b6860..b5218030b64d 100644
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -2345,9 +2345,16 @@ LogicalResult FormatParser::parse() {
handleSameTypesConstraint(variableTyResolver, /*includeResults=*/true);
} else if (def.isSubClassOf("TypesMatchWith")) {
handleTypesMatchConstraint(variableTyResolver, def);
- } else if (def.getName() == "InferTypeOpInterface" &&
- !op.allResultTypesKnown()) {
- canInferResultTypes = true;
+ } else if (!op.allResultTypesKnown()) {
+ // This doesn't check the name directly to handle
+ // DeclareOpInterfaceMethods<InferTypeOpInterface>
+ // and the like.
+ // TODO: Add hasCppInterface check.
+ if (auto name = def.getValueAsOptionalString("cppClassName")) {
+ if (*name == "InferTypeOpInterface" &&
+ def.getValueAsString("cppNamespace") == "::mlir")
+ canInferResultTypes = true;
+ }
}
}