summaryrefslogtreecommitdiff
path: root/mlir/lib/Rewrite
diff options
context:
space:
mode:
authorRiver Riddle <riddleriver@gmail.com>2022-04-26 13:38:21 -0700
committerRiver Riddle <riddleriver@gmail.com>2022-05-01 12:25:05 -0700
commit3c752289912895e067eb173485cadce6c618d6d4 (patch)
tree712d9aa23d6ff22c3bc9b5141012b7df6fbf1499 /mlir/lib/Rewrite
parent5387a38c3891730b943a35d9d3e4b2d9e716acda (diff)
downloadllvm-3c752289912895e067eb173485cadce6c618d6d4.tar.gz
[mlir:PDLInterp] Refactor the implementation of result type inferrence
The current implementation uses a discrete "pdl_interp.inferred_types" operation, which acts as a "fake" handle to a type range. This op is used as a signal to pdl_interp.create_operation that types should be inferred. This is terribly awkward and clunky though: * This op doesn't have a byte code representation, and its conversion to bytecode kind of assumes that it is only used in a certain way. The current lowering is also broken and seemingly untested. * Given that this is a different operation, it gives off the assumption that it can be used multiple times, or that after the first use the value contains the inferred types. This isn't the case though, the resultant type range can never actually be used as a type range. This commit refactors the representation by removing the discrete InferredTypesOp, and instead adds a UnitAttr to pdl_interp.CreateOperation that signals when the created operations should infer their types. This leads to a much much cleaner abstraction, a more optimal bytecode lowering, and also allows for better error handling and diagnostics when a created operation doesn't actually support type inferrence. Differential Revision: https://reviews.llvm.org/D124587
Diffstat (limited to 'mlir/lib/Rewrite')
-rw-r--r--mlir/lib/Rewrite/ByteCode.cpp63
1 files changed, 33 insertions, 30 deletions
diff --git a/mlir/lib/Rewrite/ByteCode.cpp b/mlir/lib/Rewrite/ByteCode.cpp
index c2dc41a81c6f..ad4c078f2e3a 100644
--- a/mlir/lib/Rewrite/ByteCode.cpp
+++ b/mlir/lib/Rewrite/ByteCode.cpp
@@ -162,6 +162,10 @@ enum OpCode : ByteCodeField {
};
} // namespace
+/// A marker used to indicate if an operation should infer types.
+static constexpr ByteCodeField kInferTypesMarker =
+ std::numeric_limits<ByteCodeField>::max();
+
//===----------------------------------------------------------------------===//
// ByteCode Generation
//===----------------------------------------------------------------------===//
@@ -273,7 +277,6 @@ private:
void generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer);
void generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer);
void generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer);
- void generate(pdl_interp::InferredTypesOp op, ByteCodeWriter &writer);
void generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer);
void generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer);
void generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer);
@@ -723,8 +726,7 @@ void Generator::generate(Operation *op, ByteCodeWriter &writer) {
LLVM_DEBUG({
// The following list must contain all the operations that do not
// produce any bytecode.
- if (!isa<pdl_interp::CreateAttributeOp, pdl_interp::CreateTypeOp,
- pdl_interp::InferredTypesOp>(op))
+ if (!isa<pdl_interp::CreateAttributeOp, pdl_interp::CreateTypeOp>(op))
writer.appendInline(op->getLoc());
});
TypeSwitch<Operation *>(op)
@@ -742,11 +744,11 @@ void Generator::generate(Operation *op, ByteCodeWriter &writer) {
pdl_interp::GetOperandOp, pdl_interp::GetOperandsOp,
pdl_interp::GetResultOp, pdl_interp::GetResultsOp,
pdl_interp::GetUsersOp, pdl_interp::GetValueTypeOp,
- pdl_interp::InferredTypesOp, pdl_interp::IsNotNullOp,
- pdl_interp::RecordMatchOp, pdl_interp::ReplaceOp,
- pdl_interp::SwitchAttributeOp, pdl_interp::SwitchTypeOp,
- pdl_interp::SwitchTypesOp, pdl_interp::SwitchOperandCountOp,
- pdl_interp::SwitchOperationNameOp, pdl_interp::SwitchResultCountOp>(
+ pdl_interp::IsNotNullOp, pdl_interp::RecordMatchOp,
+ pdl_interp::ReplaceOp, pdl_interp::SwitchAttributeOp,
+ pdl_interp::SwitchTypeOp, pdl_interp::SwitchTypesOp,
+ pdl_interp::SwitchOperandCountOp, pdl_interp::SwitchOperationNameOp,
+ pdl_interp::SwitchResultCountOp>(
[&](auto interpOp) { this->generate(interpOp, writer); })
.Default([](Operation *) {
llvm_unreachable("unknown `pdl_interp` operation");
@@ -847,7 +849,13 @@ void Generator::generate(pdl_interp::CreateOperationOp op,
writer.append(static_cast<ByteCodeField>(attributes.size()));
for (auto it : llvm::zip(op.getInputAttributeNames(), attributes))
writer.append(std::get<0>(it), std::get<1>(it));
- writer.appendPDLValueList(op.getInputResultTypes());
+
+ // Add the result types. If the operation has inferred results, we use a
+ // marker "size" value. Otherwise, we add the list of explicit result types.
+ if (op.getInferredResultTypes())
+ writer.append(kInferTypesMarker);
+ else
+ writer.appendPDLValueList(op.getInputResultTypes());
}
void Generator::generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer) {
// Simply repoint the memory index of the result to the constant.
@@ -955,12 +963,6 @@ void Generator::generate(pdl_interp::GetValueTypeOp op,
writer.append(OpCode::GetValueType, op.getResult(), op.getValue());
}
}
-
-void Generator::generate(pdl_interp::InferredTypesOp op,
- ByteCodeWriter &writer) {
- // InferType maps to a null type as a marker for inferring result types.
- getMemIndex(op.getResult()) = getMemIndex(Type());
-}
void Generator::generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer) {
writer.append(OpCode::IsNotNull, op.getValue(), op.getSuccessors());
}
@@ -1526,30 +1528,31 @@ void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter,
state.addAttribute(name, attr);
}
- for (unsigned i = 0, e = read(); i != e; ++i) {
- if (read<PDLValue::Kind>() == PDLValue::Kind::Type) {
- state.types.push_back(read<Type>());
- continue;
- }
-
- // If we find a null range, this signals that the types are infered.
- if (TypeRange *resultTypes = read<TypeRange *>()) {
- state.types.append(resultTypes->begin(), resultTypes->end());
- continue;
- }
-
- // Handle the case where the operation has inferred types.
+ // Read in the result types. If the "size" is the sentinel value, this
+ // indicates that the result types should be inferred.
+ unsigned numResults = read();
+ if (numResults == kInferTypesMarker) {
InferTypeOpInterface::Concept *inferInterface =
state.name.getRegisteredInfo()->getInterface<InferTypeOpInterface>();
+ assert(inferInterface &&
+ "expected operation to provide InferTypeOpInterface");
// TODO: Handle failure.
- state.types.clear();
if (failed(inferInterface->inferReturnTypes(
state.getContext(), state.location, state.operands,
state.attributes.getDictionary(state.getContext()), state.regions,
state.types)))
return;
- break;
+ } else {
+ // Otherwise, this is a fixed number of results.
+ for (unsigned i = 0; i != numResults; ++i) {
+ if (read<PDLValue::Kind>() == PDLValue::Kind::Type) {
+ state.types.push_back(read<Type>());
+ } else {
+ TypeRange *resultTypes = read<TypeRange *>();
+ state.types.append(resultTypes->begin(), resultTypes->end());
+ }
+ }
}
Operation *resultOp = rewriter.create(state);