summaryrefslogtreecommitdiff
path: root/mlir/lib/IR/Block.cpp
blob: feb21493d0d38a952e960ec14b847621096054ea (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
//===- Block.cpp - MLIR Block Class ---------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/IR/Block.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Operation.h"
#include "llvm/ADT/BitVector.h"
using namespace mlir;

//===----------------------------------------------------------------------===//
// Block
//===----------------------------------------------------------------------===//

/// Tears the block down: after checking that the recorded operation ordering
/// verifies, the block is cleared and every block argument is destroyed.
Block::~Block() {
  assert(!verifyOpOrder() && "Expected valid operation ordering.");
  clear();
  for (BlockArgument blockArg : arguments) {
    blockArg.destroy();
  }
}

/// Returns the region stored in the pointer half of `parentValidOpOrderPair`.
Region *Block::getParent() const {
  Region *parentRegion = parentValidOpOrderPair.getPointer();
  return parentRegion;
}

/// Returns the parent operation of the region holding this block, or nullptr
/// when the block has no parent region.
Operation *Block::getParentOp() {
  Region *parentRegion = getParent();
  if (!parentRegion)
    return nullptr;
  return parentRegion->getParentOp();
}

/// Return if this block is the entry block in the parent region.
bool Block::isEntryBlock() { return this == &getParent()->front(); }

/// Insert this block (which must not already be in a region) right before the
/// specified block.
///
/// `block` must itself be linked into a region; this block is inserted into
/// that region's block list at the position of `block`.
void Block::insertBefore(Block *block) {
  // A block is inserted into a region (not into another block), so the
  // diagnostic names the region.
  assert(!getParent() && "already inserted into a region!");
  assert(block->getParent() && "cannot insert before a block without a parent");
  block->getParent()->getBlocks().insert(block->getIterator(), this);
}

/// Unlink this block from its current region and insert it right before the
/// specific block.
///
/// Both this block and `block` must be linked into a region: the block is
/// spliced out of its own parent's block list and into the block list of
/// `block`'s parent, at the position of `block`.
void Block::moveBefore(Block *block) {
  // The source list is reached through getParent(), so an unlinked block
  // cannot be moved; fail loudly instead of dereferencing a null region.
  assert(getParent() && "cannot move a block without a parent");
  assert(block->getParent() && "cannot insert before a block without a parent");
  block->getParent()->getBlocks().splice(
      block->getIterator(), getParent()->getBlocks(), getIterator());
}

/// Unlink this Block from its parent Region and delete it.
void Block::erase() {
  assert(getParent() && "Block has no parent");
  getParent()->getBlocks().erase(this);
}

/// Returns 'op' itself when it lies directly in this block; otherwise returns
/// the ancestor operation of 'op' that lies in this block, or nullptr when no
/// such ancestor exists.
Operation *Block::findAncestorOpInBlock(Operation &op) {
  // Walk up the chain of parent operations, starting at 'op', until one of
  // them is found directly inside this block or the chain runs out.
  for (Operation *candidate = &op; candidate;
       candidate = candidate->getParentOp()) {
    if (candidate->getBlock() == this)
      return candidate;
  }
  return nullptr;
}

/// This drops all operand uses from operations within this block, which is
/// an essential step in breaking cyclic dependences between references when
/// they are to be deleted.
void Block::dropAllReferences() {
  for (Operation &i : *this)
    i.dropAllReferences();
}

void Block::dropAllDefinedValueUses() {
  for (auto arg : getArguments())
    arg.dropAllUses();
  for (auto &op : *this)
    op.dropAllDefinedValueUses();
  dropAllUses();
}

/// Returns the validity flag for the ordering of the child operations, which
/// is kept in the integer half of `parentValidOpOrderPair`.
bool Block::isOpOrderValid() {
  const bool orderIsValid = parentValidOpOrderPair.getInt();
  return orderIsValid;
}

/// Invalidates the current ordering of operations by clearing the validity
/// flag stored in `parentValidOpOrderPair`.
void Block::invalidateOpOrder() {
  // Validate the current ordering: in assert-enabled builds, the ordering
  // recorded so far must still verify (verifyOpOrder() returns false on
  // success) before it is discarded.
  assert(!verifyOpOrder());
  parentValidOpOrderPair.setInt(false);
}

/// Verifies the current ordering of child operations. Returns false if the
/// order is valid, true otherwise.
bool Block::verifyOpOrder() {
  // The order is already known to be invalid.
  if (!isOpOrderValid())
    return false;
  // The order is valid if there are less than 2 operations.
  if (operations.empty() || std::next(operations.begin()) == operations.end())
    return false;

  Operation *prev = nullptr;
  for (auto &i : *this) {
    // The previous operation must have a smaller order index than the next as
    // it appears earlier in the list.
    if (prev && prev->orderIndex != Operation::kInvalidOrderIdx &&
        prev->orderIndex >= i.orderIndex)
      return true;
    prev = &i;
  }
  return false;
}

/// Recomputes the ordering of child operations within the block.
void Block::recomputeOpOrder() {
  parentValidOpOrderPair.setInt(true);

  unsigned orderIndex = 0;
  for (auto &op : *this)
    op.orderIndex = (orderIndex += Operation::kOrderStride);
}

//===----------------------------------------------------------------------===//
// Argument list management.
//===----------------------------------------------------------------------===//

/// Returns a type range built over the arguments of this block.
auto Block::getArgumentTypes() -> ValueTypeRange<BlockArgListType> {
  BlockArgListType blockArgs = getArguments();
  return ValueTypeRange<BlockArgListType>(blockArgs);
}

/// Creates a new block argument of the given type and location, appends it to
/// the argument list, and returns it. The new argument's index is the size of
/// the argument list before the append.
BlockArgument Block::addArgument(Type type, Location loc) {
  arguments.push_back(
      BlockArgument::create(type, this, arguments.size(), loc));
  return arguments.back();
}

/// Appends one argument per entry of `types`, pairing each type with the
/// location at the same position in `locs` (both must have the same size).
/// Returns the range covering only the newly appended arguments.
auto Block::addArguments(TypeRange types, ArrayRef<Location> locs)
    -> iterator_range<args_iterator> {
  assert(types.size() == locs.size() &&
         "incorrect number of block argument locations");
  const size_t firstNew = arguments.size();
  arguments.reserve(firstNew + types.size());

  for (auto entry : llvm::zip(types, locs)) {
    addArgument(std::get<0>(entry), std::get<1>(entry));
  }

  auto *argData = arguments.data();
  return {argData + firstNew, argData + arguments.size()};
}

/// Creates a new argument of the given type and location and inserts it at
/// position `index` (which may equal the current argument count). Arguments
/// at or after that position are shifted and renumbered. Returns the new
/// argument.
BlockArgument Block::insertArgument(unsigned index, Type type, Location loc) {
  assert(index <= arguments.size() && "invalid insertion index");

  BlockArgument inserted = BlockArgument::create(type, this, index, loc);
  arguments.insert(arguments.begin() + index, inserted);

  // Every argument that now sits behind the inserted one has moved up by one
  // slot, so refresh its cached argument number.
  unsigned nextNumber = index + 1;
  for (BlockArgument shifted : llvm::drop_begin(arguments, nextNumber))
    shifted.setArgNumber(nextNumber++);
  return inserted;
}

/// Inserts a new argument at the position of the argument referenced by `it`,
/// shifting the existing arguments. The block must not have predecessors.
/// The insertion index is read from the argument `it` points at.
BlockArgument Block::insertArgument(args_iterator it, Type type, Location loc) {
  assert(llvm::empty(getPredecessors()) &&
         "cannot insert arguments to blocks with predecessors");
  const unsigned position = it->getArgNumber();
  return insertArgument(position, type, loc);
}

/// Destroys the argument at `index`, removes it from the argument list, and
/// renumbers the arguments that followed it so their numbers match their new
/// positions.
void Block::eraseArgument(unsigned index) {
  assert(index < arguments.size());
  auto victim = arguments.begin() + index;
  victim->destroy();
  arguments.erase(victim);

  // The survivors from `index` onwards each moved down by one slot.
  for (BlockArgument shifted : llvm::drop_begin(arguments, index))
    shifted.setArgNumber(index++);
}

/// Erases the arguments at the listed indices. The indices are first gathered
/// into a bit vector sized to the current argument count, which is then
/// handed to the bit-vector overload.
void Block::eraseArguments(ArrayRef<unsigned> argIndices) {
  BitVector toErase(getNumArguments());
  for (unsigned argIndex : argIndices) {
    toErase.set(argIndex);
  }
  eraseArguments(toErase);
}

/// Erases every argument whose argument number has its bit set in
/// `eraseIndices`, by forwarding a matching predicate to the predicate-based
/// overload.
void Block::eraseArguments(const BitVector &eraseIndices) {
  auto isMarked = [&eraseIndices](BlockArgument arg) {
    return eraseIndices.test(arg.getArgNumber());
  };
  eraseArguments(isMarked);
}

/// Erases every argument for which `shouldEraseFn` returns true. Surviving
/// arguments keep their relative order, are compacted towards the front of
/// the list, and are renumbered to match their new positions.
void Block::eraseArguments(function_ref<bool(BlockArgument)> shouldEraseFn) {
  // `writeIt` starts at the first argument to erase; nothing to do if there
  // is none.
  auto writeIt = llvm::find_if(arguments, shouldEraseFn);
  if (writeIt == arguments.end())
    return;

  // The predicate has already accepted this argument, so destroy it right
  // away rather than querying the predicate for it a second time.
  unsigned nextArgNumber = writeIt->getArgNumber();
  writeIt->destroy();

  // Scan the rest of the list: dead arguments are destroyed in place, live
  // ones are renumbered and copied down into the next free slot.
  for (auto readIt = std::next(writeIt), endIt = arguments.end();
       readIt != endIt; ++readIt) {
    if (shouldEraseFn(*readIt)) {
      readIt->destroy();
      continue;
    }
    readIt->setArgNumber(nextArgNumber++);
    *writeIt = *readIt;
    ++writeIt;
  }

  // Everything from `writeIt` onwards is now stale; trim it off.
  arguments.erase(writeIt, arguments.end());
}

//===----------------------------------------------------------------------===//
// Terminator management
//===----------------------------------------------------------------------===//

/// Returns the last operation of this block as its terminator. Asserts that
/// the block is non-empty and that the last operation might have the
/// `IsTerminator` trait.
Operation *Block::getTerminator() {
  assert(!empty() && back().mightHaveTrait<OpTrait::IsTerminator>());
  Operation &lastOp = back();
  return &lastOp;
}

// Indexed successor access.
/// Returns the successor count of the last operation in this block, or zero
/// when the block is empty.
unsigned Block::getNumSuccessors() {
  if (empty())
    return 0;
  return back().getNumSuccessors();
}

/// Returns the `i`-th successor of this block's terminator. `i` must be less
/// than getNumSuccessors().
Block *Block::getSuccessor(unsigned i) {
  assert(i < getNumSuccessors());
  Operation *terminator = getTerminator();
  return terminator->getSuccessor(i);
}

/// Returns the predecessor of this block when the predecessor list holds
/// exactly one entry, and null otherwise.
///
/// Every incoming edge counts as its own entry, so a single block that
/// branches here more than once (e.g. a cond branch with this block as both
/// the true and false destination) does not qualify.
Block *Block::getSinglePredecessor() {
  auto predIt = pred_begin();
  if (predIt == pred_end())
    return nullptr;

  Block *onlyCandidate = *predIt;
  // Any further entry in the list disqualifies the candidate.
  if (++predIt != pred_end())
    return nullptr;
  return onlyCandidate;
}

/// Returns the block from which all incoming edges originate when there is
/// exactly one such block (it may appear several times in the predecessor
/// list). Returns null when there are no predecessors or when two entries
/// differ.
Block *Block::getUniquePredecessor() {
  auto predIt = pred_begin(), predEnd = pred_end();
  if (predIt == predEnd)
    return nullptr;

  // Every remaining entry has to name the same block as the first one.
  Block *candidate = *predIt;
  while (++predIt != predEnd) {
    if (*predIt != candidate)
      return nullptr;
  }
  return candidate;
}

//===----------------------------------------------------------------------===//
// Other
//===----------------------------------------------------------------------===//

/// Splits this block in two at `splitBefore`.
///
/// Operations in front of `splitBefore` remain in this block. The operation
/// at `splitBefore` and everything after it, including the old terminator,
/// are moved into a newly created block, which leaves this block without a
/// terminator.
///
/// Returns the newly created block; the specified iterator is invalidated.
Block *Block::splitBlock(iterator splitBefore) {
  // The new block is placed directly behind this one in the parent region's
  // block list.
  Block *tail = new Block();
  auto insertPos = std::next(Region::iterator(this));
  getParent()->getBlocks().insert(insertPos, tail);

  // Transfer the operations from the split point up to the end of this block
  // into the new block.
  tail->getOperations().splice(tail->end(), getOperations(), splitBefore,
                               end());
  return tail;
}

//===----------------------------------------------------------------------===//
// Predecessors
//===----------------------------------------------------------------------===//

/// Maps a block operand to the block that contains the operand's owner.
Block *PredecessorIterator::unwrap(BlockOperand &value) {
  auto *user = value.getOwner();
  return user->getBlock();
}

/// Returns the operand number of the block operand the iterator currently
/// refers to, i.e. the successor number within the predecessor's terminator.
unsigned PredecessorIterator::getSuccessorIndex() const {
  const unsigned operandNumber = I->getOperandNumber();
  return operandNumber;
}

//===----------------------------------------------------------------------===//
// SuccessorRange
//===----------------------------------------------------------------------===//

/// Constructs an empty successor range by delegating to the (base, count)
/// constructor with a null base and a count of zero.
SuccessorRange::SuccessorRange() : SuccessorRange(nullptr, 0) {}

/// Builds the range of successors of `block`. The range stays empty when the
/// block has no operations or when its parent region has exactly one element;
/// otherwise it covers the block operands of the block's last operation.
SuccessorRange::SuccessorRange(Block *block) : SuccessorRange() {
  if (block->empty() || llvm::hasSingleElement(*block->getParent()))
    return;

  Operation &terminator = block->back();
  count = terminator.getNumSuccessors();
  // `base` is only pointed at the operand storage when there are successors.
  if (count)
    base = terminator.getBlockOperands().data();
}

/// Builds the range of successors of the operation `term`. The range stays
/// empty when `term` has no successors; otherwise it covers `term`'s block
/// operands.
SuccessorRange::SuccessorRange(Operation *term) : SuccessorRange() {
  count = term->getNumSuccessors();
  // `base` is only pointed at the operand storage when there are successors.
  if (count)
    base = term->getBlockOperands().data();
}

//===----------------------------------------------------------------------===//
// BlockRange
//===----------------------------------------------------------------------===//

/// Builds a block range over an array of block pointers. The range starts out
/// with a null base and a zero count; the base is only set to the array's
/// data when the array is non-empty.
BlockRange::BlockRange(ArrayRef<Block *> blocks) : BlockRange(nullptr, 0) {
  count = blocks.size();
  if (count)
    base = blocks.data();
}

/// Builds a block range over the same elements as `successors`, reusing the
/// base of its begin iterator and its size.
BlockRange::BlockRange(SuccessorRange successors)
    : BlockRange(successors.begin().getBase(), successors.size()) {}

/// Advances the owner `object` by `index` elements, whichever of the two
/// pointer kinds it currently holds.
/// See `llvm::detail::indexed_accessor_range_base` for details.
BlockRange::OwnerT BlockRange::offset_base(OwnerT object, ptrdiff_t index) {
  auto *operand = object.dyn_cast<BlockOperand *>();
  if (operand)
    return {operand + index};
  return {object.dyn_cast<Block *const *>() + index};
}

/// Returns the block at offset `index` from the owner `object`: read through
/// the block operand when the owner holds operands, or directly from the
/// block pointer array otherwise.
/// See `llvm::detail::indexed_accessor_range_base` for details.
Block *BlockRange::dereference_iterator(OwnerT object, ptrdiff_t index) {
  const auto *operand = object.dyn_cast<BlockOperand *>();
  if (operand)
    return operand[index].get();

  Block *const *blocks = object.dyn_cast<Block *const *>();
  return blocks[index];
}