summaryrefslogtreecommitdiff
path: root/mlir/lib/Dialect/Affine/Transforms/LoopCoalescing.cpp
blob: 765bc05a13a9f3792d0521b390533669ef4ca843 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
//===- LoopCoalescing.cpp - Pass transforming loop nests into single loops-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/LoopUtils.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Utils/Utils.h"
#include "mlir/Transforms/Passes.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/Support/Debug.h"

#define PASS_NAME "loop-coalescing"
#define DEBUG_TYPE PASS_NAME

using namespace mlir;

namespace {
struct LoopCoalescingPass : public LoopCoalescingBase<LoopCoalescingPass> {

  /// Walk either an scf.for or an affine.for to find a band to coalesce.
  template <typename LoopOpTy>
  static void walkLoop(LoopOpTy op) {
    // Ignore nested loops.
    if (op->template getParentOfType<LoopOpTy>())
      return;

    SmallVector<LoopOpTy, 4> loops;
    getPerfectlyNestedLoops(loops, op);
    LLVM_DEBUG(llvm::dbgs()
               << "found a perfect nest of depth " << loops.size() << '\n');

    // Look for a band of loops that can be coalesced, i.e. perfectly nested
    // loops with bounds defined above some loop.
    // 1. For each loop, find above which parent loop its operands are
    // defined.
    SmallVector<unsigned, 4> operandsDefinedAbove(loops.size());
    for (unsigned i = 0, e = loops.size(); i < e; ++i) {
      operandsDefinedAbove[i] = i;
      for (unsigned j = 0; j < i; ++j) {
        if (areValuesDefinedAbove(loops[i].getOperands(),
                                  loops[j].getRegion())) {
          operandsDefinedAbove[i] = j;
          break;
        }
      }
      LLVM_DEBUG(llvm::dbgs()
                 << "  bounds of loop " << i << " are known above depth "
                 << operandsDefinedAbove[i] << '\n');
    }

    // 2. Identify bands of loops such that the operands of all of them are
    // defined above the first loop in the band.  Traverse the nest bottom-up
    // so that modifications don't invalidate the inner loops.
    for (unsigned end = loops.size(); end > 0; --end) {
      unsigned start = 0;
      for (; start < end - 1; ++start) {
        auto maxPos =
            *std::max_element(std::next(operandsDefinedAbove.begin(), start),
                              std::next(operandsDefinedAbove.begin(), end));
        if (maxPos > start)
          continue;

        assert(maxPos == start &&
               "expected loop bounds to be known at the start of the band");
        LLVM_DEBUG(llvm::dbgs() << "  found coalesceable band from " << start
                                << " to " << end << '\n');

        auto band =
            llvm::makeMutableArrayRef(loops.data() + start, end - start);
        (void)coalesceLoops(band);
        break;
      }
      // If a band was found and transformed, keep looking at the loops above
      // the outermost transformed loop.
      if (start != end - 1)
        end = start + 1;
    }
  }

  void runOnOperation() override {
    func::FuncOp func = getOperation();
    func.walk([&](Operation *op) {
      if (auto scfForOp = dyn_cast<scf::ForOp>(op))
        walkLoop(scfForOp);
      else if (auto affineForOp = dyn_cast<AffineForOp>(op))
        walkLoop(affineForOp);
    });
  }
};

} // namespace

std::unique_ptr<OperationPass<func::FuncOp>> mlir::createLoopCoalescingPass() {
  return std::make_unique<LoopCoalescingPass>();
}