diff options
author | Alex Zinenko <zinenko@google.com> | 2020-11-04 09:42:32 +0100 |
---|---|---|
committer | Alex Zinenko <zinenko@google.com> | 2020-11-04 09:43:13 +0100 |
commit | 8475fa6ed6bb27d5abad418a7f77e9430aa825eb (patch) | |
tree | fd6f5f2f17a8e780fedd5501d8f9b42b0f27ebbf | |
parent | 4c0e255c98cc0e7769be9c9b2700d96e76aec99f (diff) | |
download | llvm-8475fa6ed6bb27d5abad418a7f77e9430aa825eb.tar.gz |
[mlir] Add a simpler lowering pattern for WhileOp representing a do-while loop
When the "after" region of a WhileOp is merely forwarding its arguments back to
the "before" region, i.e. WhileOp is a canonical do-while loop, a simpler CFG
subgraph that omits the "after" region with its extra branch operation can be
produced. Loop rotation from general "while" to "if { do-while }" is left for a
future canonicalization pattern when it becomes necessary.
Differential Revision: https://reviews.llvm.org/D90604
-rw-r--r-- | mlir/lib/Conversion/SCFToStandard/SCFToStandard.cpp | 61 | ||||
-rw-r--r-- | mlir/test/Conversion/SCFToStandard/convert-to-cfg.mlir | 21 |
2 files changed, 82 insertions, 0 deletions
diff --git a/mlir/lib/Conversion/SCFToStandard/SCFToStandard.cpp b/mlir/lib/Conversion/SCFToStandard/SCFToStandard.cpp index 953cb27eee74..425131f91a28 100644 --- a/mlir/lib/Conversion/SCFToStandard/SCFToStandard.cpp +++ b/mlir/lib/Conversion/SCFToStandard/SCFToStandard.cpp @@ -266,6 +266,17 @@ struct WhileLowering : public OpRewritePattern<WhileOp> { LogicalResult matchAndRewrite(WhileOp whileOp, PatternRewriter &rewriter) const override; }; + +/// Optimized version of the above for the case of the "after" region merely +/// forwarding its arguments back to the "before" region (i.e., a "do-while" +/// loop). This avoid inlining the "after" region completely and branches back +/// to the "before" entry instead. +struct DoWhileLowering : public OpRewritePattern<WhileOp> { + using OpRewritePattern<WhileOp>::OpRewritePattern; + + LogicalResult matchAndRewrite(WhileOp whileOp, + PatternRewriter &rewriter) const override; +}; } // namespace LogicalResult ForLowering::matchAndRewrite(ForOp forOp, @@ -507,10 +518,60 @@ LogicalResult WhileLowering::matchAndRewrite(WhileOp whileOp, return success(); } +LogicalResult +DoWhileLowering::matchAndRewrite(WhileOp whileOp, + PatternRewriter &rewriter) const { + if (!llvm::hasSingleElement(whileOp.after())) + return rewriter.notifyMatchFailure(whileOp, + "do-while simplification applicable to " + "single-block 'after' region only"); + + Block &afterBlock = whileOp.after().front(); + if (!llvm::hasSingleElement(afterBlock)) + return rewriter.notifyMatchFailure(whileOp, + "do-while simplification applicable " + "only if 'after' region has no payload"); + + auto yield = dyn_cast<scf::YieldOp>(&afterBlock.front()); + if (!yield || yield.results() != afterBlock.getArguments()) + return rewriter.notifyMatchFailure(whileOp, + "do-while simplification applicable " + "only to forwarding 'after' regions"); + + // Split the current block before the WhileOp to create the inlining point. + OpBuilder::InsertionGuard guard(rewriter); + Block *currentBlock = rewriter.getInsertionBlock(); + Block *continuation = + rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint()); + + // Only the "before" region should be inlined. + Block *before = &whileOp.before().front(); + Block *beforeLast = &whileOp.before().back(); + rewriter.inlineRegionBefore(whileOp.before(), continuation); + + // Branch to the "before" region. + rewriter.setInsertionPointToEnd(currentBlock); + rewriter.create<BranchOp>(whileOp.getLoc(), before, whileOp.inits()); + + // Loop around the "before" region based on condition. + rewriter.setInsertionPointToEnd(beforeLast); + auto condOp = cast<ConditionOp>(beforeLast->getTerminator()); + rewriter.replaceOpWithNewOp<CondBranchOp>(condOp, condOp.condition(), before, + condOp.args(), continuation, + ValueRange()); + + // Replace the op with values "yielded" from the "before" region, which are + // visible by dominance. + rewriter.replaceOp(whileOp, condOp.args()); + + return success(); +} + void mlir::populateLoopToStdConversionPatterns( OwningRewritePatternList &patterns, MLIRContext *ctx) { patterns.insert<ForLowering, IfLowering, ParallelLowering, WhileLowering>( ctx); + patterns.insert<DoWhileLowering>(ctx, /*benefit=*/2); } void SCFToStandardPass::runOnOperation() { diff --git a/mlir/test/Conversion/SCFToStandard/convert-to-cfg.mlir b/mlir/test/Conversion/SCFToStandard/convert-to-cfg.mlir index 08ad5f1d8976..c3f1325a549b 100644 --- a/mlir/test/Conversion/SCFToStandard/convert-to-cfg.mlir +++ b/mlir/test/Conversion/SCFToStandard/convert-to-cfg.mlir @@ -424,6 +424,8 @@ func @minimal_while() { scf.condition(%0) } do { // CHECK: ^[[AFTER]]: + // CHECK: "test.some_payload"() : () -> () + "test.some_payload"() : () -> () // CHECK: br ^[[BEFORE]] scf.yield } @@ -432,6 +434,25 @@ func @minimal_while() { return } +// CHECK-LABEL: @do_while +func @do_while(%arg0: f32) { + // CHECK: br ^[[BEFORE:.*]]({{.*}}: f32) + scf.while (%arg1 = %arg0) : (f32) -> (f32) { + // CHECK: ^[[BEFORE]](%[[VAL:.*]]: f32): + // CHECK: %[[COND:.*]] = "test.make_condition"() : () -> i1 + %0 = "test.make_condition"() : () -> i1 + // CHECK: cond_br %[[COND]], ^[[BEFORE]](%[[VAL]] : f32), ^[[CONT:.*]] + scf.condition(%0) %arg1 : f32 + } do { + ^bb0(%arg2: f32): + // CHECK-NOT: br ^[[BEFORE]] + scf.yield %arg2 : f32 + } + // CHECK: ^[[CONT]]: + // CHECK: return + return +} + // CHECK-LABEL: @while_values // CHECK-SAME: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: f32) func @while_values(%arg0: i32, %arg1: f32) { |