summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlex Zinenko <zinenko@google.com>2020-11-04 09:42:32 +0100
committerAlex Zinenko <zinenko@google.com>2020-11-04 09:43:13 +0100
commit8475fa6ed6bb27d5abad418a7f77e9430aa825eb (patch)
treefd6f5f2f17a8e780fedd5501d8f9b42b0f27ebbf
parent4c0e255c98cc0e7769be9c9b2700d96e76aec99f (diff)
downloadllvm-8475fa6ed6bb27d5abad418a7f77e9430aa825eb.tar.gz
[mlir] Add a simpler lowering pattern for WhileOp representing a do-while loop
When the "after" region of a WhileOp is merely forwarding its arguments back to the "before" region, i.e. WhileOp is a canonical do-while loop, a simpler CFG subgraph that omits the "after" region with its extra branch operation can be produced. Loop rotation from general "while" to "if { do-while }" is left for a future canonicalization pattern when it becomes necessary. Differential Revision: https://reviews.llvm.org/D90604
-rw-r--r--mlir/lib/Conversion/SCFToStandard/SCFToStandard.cpp61
-rw-r--r--mlir/test/Conversion/SCFToStandard/convert-to-cfg.mlir21
2 files changed, 82 insertions, 0 deletions
diff --git a/mlir/lib/Conversion/SCFToStandard/SCFToStandard.cpp b/mlir/lib/Conversion/SCFToStandard/SCFToStandard.cpp
index 953cb27eee74..425131f91a28 100644
--- a/mlir/lib/Conversion/SCFToStandard/SCFToStandard.cpp
+++ b/mlir/lib/Conversion/SCFToStandard/SCFToStandard.cpp
@@ -266,6 +266,17 @@ struct WhileLowering : public OpRewritePattern<WhileOp> {
LogicalResult matchAndRewrite(WhileOp whileOp,
PatternRewriter &rewriter) const override;
};
+
+/// Optimized version of the above for the case of the "after" region merely
+/// forwarding its arguments back to the "before" region (i.e., a "do-while"
+/// loop). This avoid inlining the "after" region completely and branches back
+/// to the "before" entry instead.
+struct DoWhileLowering : public OpRewritePattern<WhileOp> {
+ using OpRewritePattern<WhileOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(WhileOp whileOp,
+ PatternRewriter &rewriter) const override;
+};
} // namespace
LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
@@ -507,10 +518,60 @@ LogicalResult WhileLowering::matchAndRewrite(WhileOp whileOp,
return success();
}
+LogicalResult
+DoWhileLowering::matchAndRewrite(WhileOp whileOp,
+ PatternRewriter &rewriter) const {
+ if (!llvm::hasSingleElement(whileOp.after()))
+ return rewriter.notifyMatchFailure(whileOp,
+ "do-while simplification applicable to "
+ "single-block 'after' region only");
+
+ Block &afterBlock = whileOp.after().front();
+ if (!llvm::hasSingleElement(afterBlock))
+ return rewriter.notifyMatchFailure(whileOp,
+ "do-while simplification applicable "
+ "only if 'after' region has no payload");
+
+ auto yield = dyn_cast<scf::YieldOp>(&afterBlock.front());
+ if (!yield || yield.results() != afterBlock.getArguments())
+ return rewriter.notifyMatchFailure(whileOp,
+ "do-while simplification applicable "
+ "only to forwarding 'after' regions");
+
+ // Split the current block before the WhileOp to create the inlining point.
+ OpBuilder::InsertionGuard guard(rewriter);
+ Block *currentBlock = rewriter.getInsertionBlock();
+ Block *continuation =
+ rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
+
+ // Only the "before" region should be inlined.
+ Block *before = &whileOp.before().front();
+ Block *beforeLast = &whileOp.before().back();
+ rewriter.inlineRegionBefore(whileOp.before(), continuation);
+
+ // Branch to the "before" region.
+ rewriter.setInsertionPointToEnd(currentBlock);
+ rewriter.create<BranchOp>(whileOp.getLoc(), before, whileOp.inits());
+
+ // Loop around the "before" region based on condition.
+ rewriter.setInsertionPointToEnd(beforeLast);
+ auto condOp = cast<ConditionOp>(beforeLast->getTerminator());
+ rewriter.replaceOpWithNewOp<CondBranchOp>(condOp, condOp.condition(), before,
+ condOp.args(), continuation,
+ ValueRange());
+
+ // Replace the op with values "yielded" from the "before" region, which are
+ // visible by dominance.
+ rewriter.replaceOp(whileOp, condOp.args());
+
+ return success();
+}
+
void mlir::populateLoopToStdConversionPatterns(
OwningRewritePatternList &patterns, MLIRContext *ctx) {
patterns.insert<ForLowering, IfLowering, ParallelLowering, WhileLowering>(
ctx);
+ patterns.insert<DoWhileLowering>(ctx, /*benefit=*/2);
}
void SCFToStandardPass::runOnOperation() {
diff --git a/mlir/test/Conversion/SCFToStandard/convert-to-cfg.mlir b/mlir/test/Conversion/SCFToStandard/convert-to-cfg.mlir
index 08ad5f1d8976..c3f1325a549b 100644
--- a/mlir/test/Conversion/SCFToStandard/convert-to-cfg.mlir
+++ b/mlir/test/Conversion/SCFToStandard/convert-to-cfg.mlir
@@ -424,6 +424,8 @@ func @minimal_while() {
scf.condition(%0)
} do {
// CHECK: ^[[AFTER]]:
+ // CHECK: "test.some_payload"() : () -> ()
+ "test.some_payload"() : () -> ()
// CHECK: br ^[[BEFORE]]
scf.yield
}
@@ -432,6 +434,25 @@ func @minimal_while() {
return
}
+// CHECK-LABEL: @do_while
+func @do_while(%arg0: f32) {
+ // CHECK: br ^[[BEFORE:.*]]({{.*}}: f32)
+ scf.while (%arg1 = %arg0) : (f32) -> (f32) {
+ // CHECK: ^[[BEFORE]](%[[VAL:.*]]: f32):
+ // CHECK: %[[COND:.*]] = "test.make_condition"() : () -> i1
+ %0 = "test.make_condition"() : () -> i1
+ // CHECK: cond_br %[[COND]], ^[[BEFORE]](%[[VAL]] : f32), ^[[CONT:.*]]
+ scf.condition(%0) %arg1 : f32
+ } do {
+ ^bb0(%arg2: f32):
+ // CHECK-NOT: br ^[[BEFORE]]
+ scf.yield %arg2 : f32
+ }
+ // CHECK: ^[[CONT]]:
+ // CHECK: return
+ return
+}
+
// CHECK-LABEL: @while_values
// CHECK-SAME: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: f32)
func @while_values(%arg0: i32, %arg1: f32) {