diff options
author | Jacques Pienaar <jpienaar@google.com> | 2022-01-04 08:28:59 -0800 |
---|---|---|
committer | Jacques Pienaar <jpienaar@google.com> | 2022-01-04 08:28:59 -0800 |
commit | 05594de2d77b6f4735b8d8d417039b60987b3a79 (patch) | |
tree | d75e64dedd7382e15a79f133dabc5e10721e48b7 | |
parent | da6b0d0b768e3ecb1af2fd9df2d98510f7aff45c (diff) | |
download | llvm-05594de2d77b6f4735b8d8d417039b60987b3a79.tar.gz |
[mlir][ods] Handle DeclareOpInterfaceMethods in formatgen
Previously it would not consider ops with
DeclareOpInterfaceMethods<InferTypeOpInterface> as having the
InferTypeOpInterface interfaces added. The OpInterface nested inside
DeclareOpInterfaceMethods is not retained so that one could query it, so
check for the the C++ class directly (a bit raw/low level - will be
addressed in follow up).
Differential Revision: https://reviews.llvm.org/D116572
-rw-r--r-- | mlir/test/lib/Dialect/Test/TestDialect.cpp | 9 | ||||
-rw-r--r-- | mlir/test/lib/Dialect/Test/TestOps.td | 6 | ||||
-rw-r--r-- | mlir/test/mlir-tblgen/op-format.mlir | 5 | ||||
-rw-r--r-- | mlir/tools/mlir-tblgen/OpFormatGen.cpp | 13 |
4 files changed, 29 insertions, 4 deletions
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp index aee0bdb13970..441817803ef0 100644 --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -264,6 +264,15 @@ Operation *TestDialect::materializeConstant(OpBuilder &builder, Attribute value, return builder.create<TestOpConstant>(loc, type, value); } +::mlir::LogicalResult FormatInferType2Op::inferReturnTypes( + ::mlir::MLIRContext *context, ::llvm::Optional<::mlir::Location> location, + ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes, + ::mlir::RegionRange regions, + ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) { + inferredReturnTypes.assign({::mlir::IntegerType::get(context, 16)}); + return ::mlir::success(); +} + void *TestDialect::getRegisteredInterfaceForOp(TypeID typeID, OperationName opName) { if (opName.getIdentifier() == "test.unregistered_side_effect_op" && diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index 39f0b0b7da56..6fad11b85ad8 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -2139,6 +2139,12 @@ def FormatInferTypeOp : TEST_Op<"format_infer_type", [InferTypeOpInterface]> { }]; } +// Check that formatget supports DeclareOpInterfaceMethods. +def FormatInferType2Op : TEST_Op<"format_infer_type2", [DeclareOpInterfaceMethods<InferTypeOpInterface>]> { + let results = (outs AnyType); + let assemblyFormat = "attr-dict"; +} + // Base class for testing mixing allOperandTypes, allOperands, and // inferResultTypes. class FormatInferAllTypesBaseOp<string mnemonic, list<OpTrait> traits = []> diff --git a/mlir/test/mlir-tblgen/op-format.mlir b/mlir/test/mlir-tblgen/op-format.mlir index 152cd0a554f1..77afc41f6541 100644 --- a/mlir/test/mlir-tblgen/op-format.mlir +++ b/mlir/test/mlir-tblgen/op-format.mlir @@ -409,7 +409,10 @@ test.format_infer_variadic_type_from_non_variadic %i64, %i64 : i64 //===----------------------------------------------------------------------===// // CHECK: test.format_infer_type -%ignored_res7 = test.format_infer_type +%ignored_res7a = test.format_infer_type + +// CHECK: test.format_infer_type2 +%ignored_res7b = test.format_infer_type2 // CHECK: test.format_infer_type_all_operands_and_types(%[[I64]], %[[I32]]) : i64, i32 %ignored_res8:2 = test.format_infer_type_all_operands_and_types(%i64, %i32) : i64, i32 diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp index 02d0e81b6860..b5218030b64d 100644 --- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp @@ -2345,9 +2345,16 @@ LogicalResult FormatParser::parse() { handleSameTypesConstraint(variableTyResolver, /*includeResults=*/true); } else if (def.isSubClassOf("TypesMatchWith")) { handleTypesMatchConstraint(variableTyResolver, def); - } else if (def.getName() == "InferTypeOpInterface" && - !op.allResultTypesKnown()) { - canInferResultTypes = true; + } else if (!op.allResultTypesKnown()) { + // This doesn't check the name directly to handle + // DeclareOpInterfaceMethods<InferTypeOpInterface> + // and the like. + // TODO: Add hasCppInterface check. + if (auto name = def.getValueAsOptionalString("cppClassName")) { + if (*name == "InferTypeOpInterface" && + def.getValueAsString("cppNamespace") == "::mlir") + canInferResultTypes = true; + } } } |