//===- ControlFlowInterfaces.cpp - ControlFlow Interfaces -----------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include #include "mlir/IR/BuiltinTypes.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "llvm/ADT/SmallPtrSet.h" using namespace mlir; //===----------------------------------------------------------------------===// // ControlFlowInterfaces //===----------------------------------------------------------------------===// #include "mlir/Interfaces/ControlFlowInterfaces.cpp.inc" SuccessorOperands::SuccessorOperands(MutableOperandRange forwardedOperands) : producedOperandCount(0), forwardedOperands(std::move(forwardedOperands)) { } SuccessorOperands::SuccessorOperands(unsigned int producedOperandCount, MutableOperandRange forwardedOperands) : producedOperandCount(producedOperandCount), forwardedOperands(std::move(forwardedOperands)) {} //===----------------------------------------------------------------------===// // BranchOpInterface //===----------------------------------------------------------------------===// /// Returns the `BlockArgument` corresponding to operand `operandIndex` in some /// successor if 'operandIndex' is within the range of 'operands', or /// std::nullopt if `operandIndex` isn't a successor operand index. std::optional detail::getBranchSuccessorArgument(const SuccessorOperands &operands, unsigned operandIndex, Block *successor) { OperandRange forwardedOperands = operands.getForwardedOperands(); // Check that the operands are valid. if (forwardedOperands.empty()) return std::nullopt; // Check to ensure that this operand is within the range. 
unsigned operandsStart = forwardedOperands.getBeginOperandIndex(); if (operandIndex < operandsStart || operandIndex >= (operandsStart + forwardedOperands.size())) return std::nullopt; // Index the successor. unsigned argIndex = operands.getProducedOperandCount() + operandIndex - operandsStart; return successor->getArgument(argIndex); } /// Verify that the given operands match those of the given successor block. LogicalResult detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo, const SuccessorOperands &operands) { // Check the count. unsigned operandCount = operands.size(); Block *destBB = op->getSuccessor(succNo); if (operandCount != destBB->getNumArguments()) return op->emitError() << "branch has " << operandCount << " operands for successor #" << succNo << ", but target block has " << destBB->getNumArguments(); // Check the types. for (unsigned i = operands.getProducedOperandCount(); i != operandCount; ++i) { if (!cast(op).areTypesCompatible( operands[i].getType(), destBB->getArgument(i).getType())) return op->emitError() << "type mismatch for bb argument #" << i << " of successor #" << succNo; } return success(); } //===----------------------------------------------------------------------===// // RegionBranchOpInterface //===----------------------------------------------------------------------===// /// Verify that types match along all region control flow edges originating from /// `sourceNo` (region # if source is a region, std::nullopt if source is parent /// op). `getInputsTypesForRegion` is a function that returns the types of the /// inputs that flow from `sourceIndex' to the given region, or std::nullopt if /// the exact type match verification is not necessary (e.g., if the Op verifies /// the match itself). 
static LogicalResult verifyTypesAlongAllEdges(
    Operation *op, std::optional<unsigned> sourceNo,
    function_ref<std::optional<TypeRange>(std::optional<unsigned>)>
        getInputsTypesForRegion) {
  auto regionInterface = cast<RegionBranchOpInterface>(op);

  SmallVector<RegionSuccessor, 2> successors;
  regionInterface.getSuccessorRegions(sourceNo, successors);

  for (RegionSuccessor &succ : successors) {
    // The successor is identified by its region number, or left empty when
    // the edge leads back to the parent op.
    std::optional<unsigned> succRegionNo;
    if (!succ.isParent())
      succRegionNo = succ.getSuccessor()->getRegionNumber();

    // Appends a "from <source> to <target>" description of the current edge
    // to `diag`.
    auto printEdgeName = [&](InFlightDiagnostic &diag) -> InFlightDiagnostic & {
      diag << "from ";
      if (sourceNo)
        diag << "Region #" << sourceNo.value();
      else
        diag << "parent operands";

      diag << " to ";
      if (succRegionNo)
        diag << "Region #" << succRegionNo.value();
      else
        diag << "parent results";
      return diag;
    };

    // No types means the callback opted out of verification for this edge.
    std::optional<TypeRange> sourceTypes =
        getInputsTypesForRegion(succRegionNo);
    if (!sourceTypes.has_value())
      continue;

    TypeRange succInputsTypes = succ.getSuccessorInputs().getTypes();
    if (sourceTypes->size() != succInputsTypes.size()) {
      InFlightDiagnostic diag = op->emitOpError(" region control flow edge ");
      return printEdgeName(diag) << ": source has " << sourceTypes->size()
                                 << " operands, but target successor needs "
                                 << succInputsTypes.size();
    }

    for (const auto &typesIdx :
         llvm::enumerate(llvm::zip(*sourceTypes, succInputsTypes))) {
      Type sourceType = std::get<0>(typesIdx.value());
      Type inputType = std::get<1>(typesIdx.value());
      if (!regionInterface.areTypesCompatible(sourceType, inputType)) {
        InFlightDiagnostic diag = op->emitOpError(" along control flow edge ");
        return printEdgeName(diag)
               << ": source type #" << typesIdx.index() << " " << sourceType
               << " should match input type #" << typesIdx.index() << " "
               << inputType;
      }
    }
  }
  return success();
}

/// Verify that types match along control flow edges described by the given op.
LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) { auto regionInterface = cast(op); auto inputTypesFromParent = [&](std::optional regionNo) -> TypeRange { return regionInterface.getSuccessorEntryOperands(regionNo).getTypes(); }; // Verify types along control flow edges originating from the parent. if (failed(verifyTypesAlongAllEdges(op, std::nullopt, inputTypesFromParent))) return failure(); auto areTypesCompatible = [&](TypeRange lhs, TypeRange rhs) { if (lhs.size() != rhs.size()) return false; for (auto types : llvm::zip(lhs, rhs)) { if (!regionInterface.areTypesCompatible(std::get<0>(types), std::get<1>(types))) { return false; } } return true; }; // Verify types along control flow edges originating from each region. for (unsigned regionNo : llvm::seq(0U, op->getNumRegions())) { Region ®ion = op->getRegion(regionNo); // Since there can be multiple `ReturnLike` terminators or others // implementing the `RegionBranchTerminatorOpInterface`, all should have the // same operand types when passing them to the same region. std::optional regionReturnOperands; for (Block &block : region) { Operation *terminator = block.getTerminator(); auto terminatorOperands = getRegionBranchSuccessorOperands(terminator, regionNo); if (!terminatorOperands) continue; if (!regionReturnOperands) { regionReturnOperands = terminatorOperands; continue; } // Found more than one ReturnLike terminator. Make sure the operand types // match with the first one. if (!areTypesCompatible(regionReturnOperands->getTypes(), terminatorOperands->getTypes())) return op->emitOpError("Region #") << regionNo << " operands mismatch between return-like terminators"; } auto inputTypesFromRegion = [&](std::optional regionNo) -> std::optional { // If there is no return-like terminator, the op itself should verify // type consistency. if (!regionReturnOperands) return std::nullopt; // All successors get the same set of operand types. 
return TypeRange(regionReturnOperands->getTypes()); }; if (failed(verifyTypesAlongAllEdges(op, regionNo, inputTypesFromRegion))) return failure(); } return success(); } /// Return `true` if region `r` is reachable from region `begin` according to /// the RegionBranchOpInterface (by taking a branch). static bool isRegionReachable(Region *begin, Region *r) { assert(begin->getParentOp() == r->getParentOp() && "expected that both regions belong to the same op"); auto op = cast(begin->getParentOp()); SmallVector visited(op->getNumRegions(), false); visited[begin->getRegionNumber()] = true; // Retrieve all successors of the region and enqueue them in the worklist. SmallVector worklist; auto enqueueAllSuccessors = [&](unsigned index) { SmallVector successors; op.getSuccessorRegions(index, successors); for (RegionSuccessor successor : successors) if (!successor.isParent()) worklist.push_back(successor.getSuccessor()->getRegionNumber()); }; enqueueAllSuccessors(begin->getRegionNumber()); // Process all regions in the worklist via DFS. while (!worklist.empty()) { unsigned nextRegion = worklist.pop_back_val(); if (nextRegion == r->getRegionNumber()) return true; if (visited[nextRegion]) continue; visited[nextRegion] = true; enqueueAllSuccessors(nextRegion); } return false; } /// Return `true` if `a` and `b` are in mutually exclusive regions. /// /// 1. Find the first common of `a` and `b` (ancestor) that implements /// RegionBranchOpInterface. /// 2. Determine the regions `regionA` and `regionB` in which `a` and `b` are /// contained. /// 3. Check if `regionA` and `regionB` are mutually exclusive. They are /// mutually exclusive if they are not reachable from each other as per /// RegionBranchOpInterface::getSuccessorRegions. bool mlir::insideMutuallyExclusiveRegions(Operation *a, Operation *b) { assert(a && "expected non-empty operation"); assert(b && "expected non-empty operation"); auto branchOp = a->getParentOfType(); while (branchOp) { // Check if b is inside branchOp. 
(We already know that a is.) if (!branchOp->isProperAncestor(b)) { // Check next enclosing RegionBranchOpInterface. branchOp = branchOp->getParentOfType(); continue; } // b is contained in branchOp. Retrieve the regions in which `a` and `b` // are contained. Region *regionA = nullptr, *regionB = nullptr; for (Region &r : branchOp->getRegions()) { if (r.findAncestorOpInRegion(*a)) { assert(!regionA && "already found a region for a"); regionA = &r; } if (r.findAncestorOpInRegion(*b)) { assert(!regionB && "already found a region for b"); regionB = &r; } } assert(regionA && regionB && "could not find region of op"); // `a` and `b` are in mutually exclusive regions if both regions are // distinct and neither region is reachable from the other region. return regionA != regionB && !isRegionReachable(regionA, regionB) && !isRegionReachable(regionB, regionA); } // Could not find a common RegionBranchOpInterface among a's and b's // ancestors. return false; } bool RegionBranchOpInterface::isRepetitiveRegion(unsigned index) { Region *region = &getOperation()->getRegion(index); return isRegionReachable(region, region); } void RegionBranchOpInterface::getSuccessorRegions( std::optional index, SmallVectorImpl ®ions) { unsigned numInputs = 0; if (index) { // If the predecessor is a region, get the number of operands from an // exiting terminator in the region. for (Block &block : getOperation()->getRegion(*index)) { Operation *terminator = block.getTerminator(); if (getRegionBranchSuccessorOperands(terminator, *index)) { numInputs = terminator->getNumOperands(); break; } } } else { // Otherwise, use the number of parent operation operands. 
numInputs = getOperation()->getNumOperands(); } SmallVector operands(numInputs, nullptr); getSuccessorRegions(index, operands, regions); } Region *mlir::getEnclosingRepetitiveRegion(Operation *op) { while (Region *region = op->getParentRegion()) { op = region->getParentOp(); if (auto branchOp = dyn_cast(op)) if (branchOp.isRepetitiveRegion(region->getRegionNumber())) return region; } return nullptr; } Region *mlir::getEnclosingRepetitiveRegion(Value value) { Region *region = value.getParentRegion(); while (region) { Operation *op = region->getParentOp(); if (auto branchOp = dyn_cast(op)) if (branchOp.isRepetitiveRegion(region->getRegionNumber())) return region; region = op->getParentRegion(); } return nullptr; } //===----------------------------------------------------------------------===// // RegionBranchTerminatorOpInterface //===----------------------------------------------------------------------===// /// Returns true if the given operation is either annotated with the /// `ReturnLike` trait or implements the `RegionBranchTerminatorOpInterface`. bool mlir::isRegionReturnLike(Operation *operation) { return dyn_cast(operation) || operation->hasTrait(); } /// Returns the mutable operands that are passed to the region with the given /// `regionIndex`. If the operation does not implement the /// `RegionBranchTerminatorOpInterface` and is not marked as `ReturnLike`, the /// result will be `std::nullopt`. In all other cases, the resulting /// `OperandRange` represents all operands that are passed to the specified /// successor region. If `regionIndex` is `std::nullopt`, all operands that are /// passed to the parent operation will be returned. std::optional mlir::getMutableRegionBranchSuccessorOperands( Operation *operation, std::optional regionIndex) { // Try to query a RegionBranchTerminatorOpInterface to determine // all successor operands that will be passed to the successor // input arguments. 
if (auto regionTerminatorInterface = dyn_cast(operation)) return regionTerminatorInterface.getMutableSuccessorOperands(regionIndex); // TODO: The ReturnLike trait should imply a default implementation of the // RegionBranchTerminatorOpInterface. This would make this code significantly // easier. Furthermore, this may even make this function obsolete. if (operation->hasTrait()) return MutableOperandRange(operation); return std::nullopt; } /// Returns the read only operands that are passed to the region with the given /// `regionIndex`. See `getMutableRegionBranchSuccessorOperands` for more /// information. std::optional mlir::getRegionBranchSuccessorOperands(Operation *operation, std::optional regionIndex) { auto range = getMutableRegionBranchSuccessorOperands(operation, regionIndex); if (range) return range->operator OperandRange(); return std::nullopt; }