summaryrefslogtreecommitdiff
path: root/mlir/lib/Reducer/ReductionNode.cpp
blob: 2288b0ac0be151d2b920935dda92e0818e8b9ebe (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
//===- ReductionNode.cpp - Reduction Node Implementation -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the reduction nodes which are used to keep track of the
// metadata for a specific generated variant within a reduction pass and are the
// building blocks of the reduction tree structure. A reduction tree is used to
// keep track of the different generated variants throughout a reduction pass in
// the MLIR Reduce tool.
//
//===----------------------------------------------------------------------===//

#include "mlir/Reducer/ReductionNode.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "llvm/ADT/STLExtras.h"

#include <algorithm>
#include <limits>

using namespace mlir;

/// Constructs a node of the reduction tree.
///
/// A null `parentNode` denotes the root node, whose parent pointer refers back
/// to the node itself. `ranges` is stored both as the working ranges and as
/// the start ranges of this node. Every non-root node is initialized from its
/// parent's module and region right away.
ReductionNode::ReductionNode(
    ReductionNode *parentNode, const std::vector<Range> &ranges,
    llvm::SpecificBumpPtrAllocator<ReductionNode> &allocator)
    : parent(parentNode ? parentNode : this),
      size(std::numeric_limits<size_t>::max()), ranges(ranges),
      startRanges(ranges), allocator(allocator) {
  // The root node has no parent to be initialized from.
  if (parent == this)
    return;

  LogicalResult status = initialize(parent->getModule(), parent->getRegion());
  if (failed(status))
    llvm_unreachable("unexpected initialization failure");
}

/// Clones `parentModule` into this node and records the region of the clone
/// that corresponds to `targetRegion`. Always returns success.
LogicalResult ReductionNode::initialize(ModuleOp parentModule,
                                        Region &targetRegion) {
  // Cloning through a mapping records, for every block of the parent module,
  // its counterpart in the clone.
  BlockAndValueMapping cloneMapping;
  auto *clonedOp = parentModule->clone(cloneMapping);
  module = cast<ModuleOp>(clonedOp);

  // The first block of `targetRegion` serves as the key for finding the cloned
  // region: the counterpart block's parent is the region we are after.
  Block *clonedBlock = cloneMapping.lookup(&*targetRegion.begin());
  region = clonedBlock->getParent();
  return success();
}

/// Generates the next batch of variants below this node and returns only the
/// variants created by this call.
///
/// If no variant has been created from this node yet and `ranges` holds more
/// than one range, N variants are created (N being the length of `ranges`),
/// each one keeping every range except one. Otherwise the longest range in
/// `ranges` is split in half and 2 new variants are created, each keeping one
/// of the halves in place of the split range; `ranges` itself is updated to
/// hold both halves.
///
/// Returns an empty list when nothing can be split any further, i.e. `ranges`
/// is empty or its longest range has a length of at most 1.
ArrayRef<ReductionNode *> ReductionNode::generateNewVariants() {
  const size_t oldNumVariant = getVariants().size();

  auto createNewNode = [this](const std::vector<Range> &ranges) {
    return new (allocator.Allocate()) ReductionNode(this, ranges, allocator);
  };

  // If we haven't created any variant yet, then we can create variants by
  // removing each of the ranges respectively. For example, given
  // {{1, 3}, {4, 9}}, we can produce variants with range {{1, 3}} and {{4, 9}}.
  if (variants.empty() && getRanges().size() > 1) {
    for (const Range &range : getRanges()) {
      std::vector<Range> subRanges = getRanges();
      llvm::erase_value(subRanges, range);
      variants.push_back(createNewNode(subRanges));
    }

    return getVariants().drop_front(oldNumVariant);
  }

  // Without any range there is nothing to split (and no element to select
  // below).
  if (ranges.empty())
    return {};

  // At this point we have created the type of variants mentioned above. We
  // would like to split the max range into 2 to create 2 new variants.
  // Continuing the above example, we split the range {4, 9} into {4, 6},
  // {6, 9}, and create two variants with range {{1, 3}, {4, 6}} and
  // {{1, 3}, {6, 9}}. The final ranges vector will be
  // {{1, 3}, {4, 6}, {6, 9}}.
  //
  // Note that std::max_element expects a "less than" predicate; comparing with
  // "greater than" would select the shortest range instead of the longest one.
  auto maxElement = std::max_element(
      ranges.begin(), ranges.end(), [](const Range &lhs, const Range &rhs) {
        return (lhs.second - lhs.first) < (rhs.second - rhs.first);
      });

  // The length of the longest range is at most 1, we can't split it to create
  // new variants.
  if (maxElement->second - maxElement->first <= 1)
    return {};

  const Range maxRange = *maxElement;
  const int half = (maxRange.first + maxRange.second) / 2;

  // Each of the two variants gets a copy of the current ranges in which the
  // longest range is replaced by one of its halves.
  std::vector<Range> subRanges = getRanges();
  auto subRangesIter = subRanges.begin() + (maxElement - ranges.begin());
  *subRangesIter = std::make_pair(maxRange.first, half);
  variants.push_back(createNewNode(subRanges));
  *subRangesIter = std::make_pair(half, maxRange.second);
  variants.push_back(createNewNode(subRanges));

  // Replace the range that has been split by its two halves, keeping their
  // position: the lower half overwrites the original entry and the upper half
  // is inserted right after it. `maxElement` is not used after the insertion,
  // which may invalidate it.
  *maxElement = std::make_pair(maxRange.first, half);
  ranges.insert(maxElement + 1, std::make_pair(half, maxRange.second));

  return getVariants().drop_front(oldNumVariant);
}

/// Records the outcome of testing this node: `result` carries the
/// interestingness and the size, in that order.
///
/// An interesting node gets its ranges reset to a single range spanning all
/// operations currently in its region. For any other outcome the module is
/// released and erased.
void ReductionNode::update(std::pair<Tester::Interestingness, size_t> result) {
  interesting = result.first;
  size = result.second;

  // A node that is not interesting won't be explored any further, so its
  // ranges may stay stale; drop its module to save some memory.
  if (interesting != Tester::Interestingness::True) {
    module.release()->erase();
    return;
  }

  // Applying the reduction may have changed the number of operations in the
  // region, so start over with one range covering all of them.
  ranges.clear();
  ranges.emplace_back(0, std::distance(region->op_begin(), region->op_end()));
}

/// Single Path traversal: follows the smallest successful variant at each
/// level until no new successful variant can be created at that level.
///
/// Returns the variants newly generated by the node chosen to continue the
/// traversal, or an empty list while some variant of `node`'s parent has not
/// been tested yet.
ArrayRef<ReductionNode *>
ReductionNode::iterator<SinglePath>::getNeighbors(ReductionNode *node) {
  ReductionNode *parentNode = node->getParent();
  ArrayRef<ReductionNode *> siblings = parentNode->getVariants();

  // The parent may have created several variants that are still waiting to be
  // examined. Since this approach continues from the smallest variant, the
  // decision has to be postponed until the last one of them has been tested.
  bool hasUntested = llvm::any_of(siblings, [](ReductionNode *sibling) {
    return sibling->isInteresting() == Tester::Interestingness::Untested;
  });
  if (hasUntested)
    return {};

  // Pick the smallest variant among the interesting ones; on equal sizes the
  // earliest one is kept.
  ReductionNode *smallest = nullptr;
  for (ReductionNode *sibling : siblings) {
    if (sibling->isInteresting() != Tester::Interestingness::True)
      continue;
    if (!smallest || sibling->getSize() < smallest->getSize())
      smallest = sibling;
  }

  // Continue from the smallest variant only if it is smaller than the parent.
  // Otherwise none of the variants is an improvement, so let the parent
  // generate more variants.
  bool foundSmaller = smallest && smallest->getSize() < parentNode->getSize();
  ReductionNode *next = foundSmaller ? smallest : parentNode;
  return next->generateNewVariants();
}