summaryrefslogtreecommitdiff
path: root/mlir
diff options
context:
space:
mode:
authorRiver Riddle <riverriddle@google.com>2020-01-13 15:23:01 -0800
committerRiver Riddle <riverriddle@google.com>2020-01-13 15:23:28 -0800
commit6fca03f0cae77c275870c4569bfeeb7ca0f561a6 (patch)
tree875af93a8967eda4aac63649490b2c113b8e5093 /mlir
parent6d57511e0b6f95a369efe7274923a36de3489e7b (diff)
downloadllvm-6fca03f0cae77c275870c4569bfeeb7ca0f561a6.tar.gz
[mlir] Update the use-list algorithms in SymbolTable to support nested references.
Summary: This updates the use list algorithms to support querying from a specific symbol, allowing for the collection and detection of nested references. This works by walking the parent "symbol scopes" and applying the existing algorithm at each level. Reviewed By: jpienaar Differential Revision: https://reviews.llvm.org/D72042
Diffstat (limited to 'mlir')
-rw-r--r--mlir/include/mlir/IR/SymbolTable.h24
-rw-r--r--mlir/lib/IR/SymbolTable.cpp483
-rw-r--r--mlir/test/IR/test-symbol-rauw.mlir33
-rw-r--r--mlir/test/IR/test-symbol-uses.mlir44
-rw-r--r--mlir/test/lib/IR/TestSymbolUses.cpp102
5 files changed, 529 insertions, 157 deletions
diff --git a/mlir/include/mlir/IR/SymbolTable.h b/mlir/include/mlir/IR/SymbolTable.h
index 07829186cbf7..2df39ea1b736 100644
--- a/mlir/include/mlir/IR/SymbolTable.h
+++ b/mlir/include/mlir/IR/SymbolTable.h
@@ -50,16 +50,27 @@ public:
// Symbol Utilities
//===--------------------------------------------------------------------===//
+ /// Returns true if the given operation defines a symbol.
+ static bool isSymbol(Operation *op);
+
+ /// Returns the name of the given symbol operation.
+ static StringRef getSymbolName(Operation *symbol);
+ /// Sets the name of the given symbol operation.
+ static void setSymbolName(Operation *symbol, StringRef name);
+
/// Returns the operation registered with the given symbol name with the
/// regions of 'symbolTableOp'. 'symbolTableOp' is required to be an operation
/// with the 'OpTrait::SymbolTable' trait.
static Operation *lookupSymbolIn(Operation *op, StringRef symbol);
+ static Operation *lookupSymbolIn(Operation *op, SymbolRefAttr symbol);
/// Returns the operation registered with the given symbol name within the
/// closest parent operation of, or including, 'from' with the
/// 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol was
/// found.
static Operation *lookupNearestSymbolFrom(Operation *from, StringRef symbol);
+ static Operation *lookupNearestSymbolFrom(Operation *from,
+ SymbolRefAttr symbol);
/// This class represents a specific symbol use.
class SymbolUse {
@@ -110,6 +121,7 @@ public:
/// symbol table, and not the op itself. This function returns None if there
/// are any unknown operations that may potentially be symbol tables.
static Optional<UseRange> getSymbolUses(StringRef symbol, Operation *from);
+ static Optional<UseRange> getSymbolUses(Operation *symbol, Operation *from);
/// Return if the given symbol is known to have no uses that are nested
/// within the given operation 'from'. This does not traverse into any nested
@@ -120,6 +132,7 @@ public:
/// tables. This doesn't necessarily mean that there are no uses, we just
/// can't conservatively prove it.
static bool symbolKnownUseEmpty(StringRef symbol, Operation *from);
+ static bool symbolKnownUseEmpty(Operation *symbol, Operation *from);
/// Attempt to replace all uses of the given symbol 'oldSymbol' with the
/// provided symbol 'newSymbol' that are nested within the given operation
@@ -132,6 +145,9 @@ public:
LLVM_NODISCARD static LogicalResult replaceAllSymbolUses(StringRef oldSymbol,
StringRef newSymbol,
Operation *from);
+ LLVM_NODISCARD static LogicalResult
+ replaceAllSymbolUses(Operation *oldSymbol, StringRef newSymbolName,
+ Operation *from);
private:
Operation *symbolTableOp;
@@ -207,14 +223,14 @@ public:
/// operation 'from'.
/// Note: See mlir::SymbolTable::getSymbolUses for more details.
Optional<::mlir::SymbolTable::UseRange> getSymbolUses(Operation *from) {
- return ::mlir::SymbolTable::getSymbolUses(getName(), from);
+ return ::mlir::SymbolTable::getSymbolUses(this->getOperation(), from);
}
/// Return if the current symbol is known to have no uses that are nested
/// within the given operation 'from'.
/// Note: See mlir::SymbolTable::symbolKnownUseEmpty for more details.
bool symbolKnownUseEmpty(Operation *from) {
- return ::mlir::SymbolTable::symbolKnownUseEmpty(getName(), from);
+ return ::mlir::SymbolTable::symbolKnownUseEmpty(this->getOperation(), from);
}
/// Attempt to replace all uses of the current symbol with the provided symbol
@@ -222,8 +238,8 @@ public:
/// Note: See mlir::SymbolTable::replaceAllSymbolUses for more details.
LLVM_NODISCARD LogicalResult replaceAllSymbolUses(StringRef newSymbol,
Operation *from) {
- return ::mlir::SymbolTable::replaceAllSymbolUses(getName(), newSymbol,
- from);
+ return ::mlir::SymbolTable::replaceAllSymbolUses(this->getOperation(),
+ newSymbol, from);
}
};
diff --git a/mlir/lib/IR/SymbolTable.cpp b/mlir/lib/IR/SymbolTable.cpp
index 83e5802093c9..057aeded242a 100644
--- a/mlir/lib/IR/SymbolTable.cpp
+++ b/mlir/lib/IR/SymbolTable.cpp
@@ -7,6 +7,8 @@
//===----------------------------------------------------------------------===//
#include "mlir/IR/SymbolTable.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallString.h"
using namespace mlir;
@@ -17,6 +19,71 @@ static bool isPotentiallyUnknownSymbolTable(Operation *op) {
return !op->getDialect() && op->getNumRegions() == 1;
}
+/// Returns the nearest symbol table from a given operation `from`. Returns
+/// nullptr if no valid parent symbol table could be found.
+static Operation *getNearestSymbolTable(Operation *from) {
+ assert(from && "expected valid operation");
+ if (isPotentiallyUnknownSymbolTable(from))
+ return nullptr;
+
+ while (!from->hasTrait<OpTrait::SymbolTable>()) {
+ from = from->getParentOp();
+
+ // Check that this is a valid op and isn't an unknown symbol table.
+ if (!from || isPotentiallyUnknownSymbolTable(from))
+ return nullptr;
+ }
+ return from;
+}
+
+/// Returns the string name of the given symbol, or None if this is not a
+/// symbol.
+static Optional<StringRef> getNameIfSymbol(Operation *symbol) {
+ auto nameAttr =
+ symbol->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
+ return nameAttr ? nameAttr.getValue() : Optional<StringRef>();
+}
+
+/// Computes the nested symbol reference attribute for the symbol 'symbolName'
+/// that are usable within the symbol table operations from 'symbol' as far up
+/// to the given operation 'within', where 'within' is an ancestor of 'symbol'.
+/// Returns success if all references up to 'within' could be computed.
+static LogicalResult
+collectValidReferencesFor(Operation *symbol, StringRef symbolName,
+ Operation *within,
+ SmallVectorImpl<SymbolRefAttr> &results) {
+ assert(within->isAncestor(symbol) && "expected 'within' to be an ancestor");
+ MLIRContext *ctx = symbol->getContext();
+
+ auto leafRef = FlatSymbolRefAttr::get(symbolName, ctx);
+ results.push_back(leafRef);
+
+ // Early exit for when 'within' is the parent of 'symbol'.
+ Operation *symbolTableOp = symbol->getParentOp();
+ if (within == symbolTableOp)
+ return success();
+
+ // Collect references until 'symbolTableOp' reaches 'within'.
+ SmallVector<FlatSymbolRefAttr, 1> nestedRefs(1, leafRef);
+ do {
+ // Each parent of 'symbol' should define a symbol table.
+ if (!symbolTableOp->hasTrait<OpTrait::SymbolTable>())
+ return failure();
+ // Each parent of 'symbol' should also be a symbol.
+ Optional<StringRef> symbolTableName = getNameIfSymbol(symbolTableOp);
+ if (!symbolTableName)
+ return failure();
+ results.push_back(SymbolRefAttr::get(*symbolTableName, nestedRefs, ctx));
+
+ symbolTableOp = symbolTableOp->getParentOp();
+ if (symbolTableOp == within)
+ break;
+ nestedRefs.insert(nestedRefs.begin(),
+ FlatSymbolRefAttr::get(*symbolTableName, ctx));
+ } while (true);
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// SymbolTable
//===----------------------------------------------------------------------===//
@@ -32,11 +99,11 @@ SymbolTable::SymbolTable(Operation *symbolTableOp)
"expected operation to have a single block");
for (auto &op : symbolTableOp->getRegion(0).front()) {
- auto nameAttr = op.getAttrOfType<StringAttr>(getSymbolAttrName());
- if (!nameAttr)
+ Optional<StringRef> name = getNameIfSymbol(&op);
+ if (!name)
continue;
- auto inserted = symbolTable.insert({nameAttr.getValue(), &op});
+ auto inserted = symbolTable.insert({*name, &op});
(void)inserted;
assert(inserted.second &&
"expected region to contain uniquely named symbol operations");
@@ -51,13 +118,13 @@ Operation *SymbolTable::lookup(StringRef name) const {
/// Erase the given symbol from the table.
void SymbolTable::erase(Operation *symbol) {
- auto nameAttr = symbol->getAttrOfType<StringAttr>(getSymbolAttrName());
- assert(nameAttr && "expected valid 'name' attribute");
+ Optional<StringRef> name = getNameIfSymbol(symbol);
+ assert(name && "expected valid 'name' attribute");
assert(symbol->getParentOp() == symbolTableOp &&
"expected this operation to be inside of the operation with this "
"SymbolTable");
- auto it = symbolTable.find(nameAttr.getValue());
+ auto it = symbolTable.find(*name);
if (it != symbolTable.end() && it->second == symbol) {
symbolTable.erase(it);
symbol->erase();
@@ -67,9 +134,6 @@ void SymbolTable::erase(Operation *symbol) {
/// Insert a new symbol into the table and associated operation, and rename it
/// as necessary to avoid collisions.
void SymbolTable::insert(Operation *symbol, Block::iterator insertPt) {
- auto nameAttr = symbol->getAttrOfType<StringAttr>(getSymbolAttrName());
- assert(nameAttr && "expected valid 'name' attribute");
-
auto &body = symbolTableOp->getRegion(0).front();
if (insertPt == Block::iterator() || insertPt == body.end())
insertPt = Block::iterator(body.getTerminator());
@@ -81,12 +145,12 @@ void SymbolTable::insert(Operation *symbol, Block::iterator insertPt) {
// Add this symbol to the symbol table, uniquing the name if a conflict is
// detected.
- if (symbolTable.insert({nameAttr.getValue(), symbol}).second)
+ StringRef name = getSymbolName(symbol);
+ if (symbolTable.insert({name, symbol}).second)
return;
-
// If a conflict was detected, then the symbol will not have been added to
// the symbol table. Try suffixes until we get to a unique name that works.
- SmallString<128> nameBuffer(nameAttr.getValue());
+ SmallString<128> nameBuffer(name);
unsigned originalLength = nameBuffer.size();
// Iteratively try suffixes until we find one that isn't used.
@@ -95,8 +159,24 @@ void SymbolTable::insert(Operation *symbol, Block::iterator insertPt) {
nameBuffer += '_';
nameBuffer += std::to_string(uniquingCounter++);
} while (!symbolTable.insert({nameBuffer, symbol}).second);
+ setSymbolName(symbol, nameBuffer);
+}
+
+/// Returns true if the given operation defines a symbol.
+bool SymbolTable::isSymbol(Operation *op) {
+ return op->hasTrait<OpTrait::Symbol>() || getNameIfSymbol(op).hasValue();
+}
+
+/// Returns the name of the given symbol operation.
+StringRef SymbolTable::getSymbolName(Operation *symbol) {
+ Optional<StringRef> name = getNameIfSymbol(symbol);
+ assert(name && "expected valid symbol name");
+ return *name;
+}
+/// Sets the name of the given symbol operation.
+void SymbolTable::setSymbolName(Operation *symbol, StringRef name) {
symbol->setAttr(getSymbolAttrName(),
- StringAttr::get(nameBuffer, symbolTableOp->getContext()));
+ StringAttr::get(name, symbol->getContext()));
}
/// Returns the operation registered with the given symbol name with the
@@ -109,30 +189,52 @@ Operation *SymbolTable::lookupSymbolIn(Operation *symbolTableOp,
// Look for a symbol with the given name.
for (auto &block : symbolTableOp->getRegion(0)) {
- for (auto &op : block) {
- auto nameAttr = op.template getAttrOfType<StringAttr>(
- mlir::SymbolTable::getSymbolAttrName());
- if (nameAttr && nameAttr.getValue() == symbol)
+ for (auto &op : block)
+ if (getNameIfSymbol(&op) == symbol)
return &op;
- }
}
return nullptr;
}
+Operation *SymbolTable::lookupSymbolIn(Operation *symbolTableOp,
+ SymbolRefAttr symbol) {
+ assert(symbolTableOp->hasTrait<OpTrait::SymbolTable>());
+
+ // Lookup the root reference for this symbol.
+ symbolTableOp = lookupSymbolIn(symbolTableOp, symbol.getRootReference());
+ if (!symbolTableOp)
+ return nullptr;
+
+ // If there are no nested references, just return the root symbol directly.
+ ArrayRef<FlatSymbolRefAttr> nestedRefs = symbol.getNestedReferences();
+ if (nestedRefs.empty())
+ return symbolTableOp;
+
+ // Verify that the root is also a symbol table.
+ if (!symbolTableOp->hasTrait<OpTrait::SymbolTable>())
+ return nullptr;
+
+ // Otherwise, lookup each of the nested non-leaf references and ensure that
+ // each corresponds to a valid symbol table.
+ for (FlatSymbolRefAttr ref : nestedRefs.drop_back()) {
+ symbolTableOp = lookupSymbolIn(symbolTableOp, ref.getValue());
+ if (!symbolTableOp || !symbolTableOp->hasTrait<OpTrait::SymbolTable>())
+ return nullptr;
+ }
+ return lookupSymbolIn(symbolTableOp, symbol.getLeafReference());
+}
/// Returns the operation registered with the given symbol name within the
/// closes parent operation with the 'OpTrait::SymbolTable' trait. Returns
/// nullptr if no valid symbol was found.
Operation *SymbolTable::lookupNearestSymbolFrom(Operation *from,
StringRef symbol) {
- assert(from && "expected valid operation");
- while (!from->hasTrait<OpTrait::SymbolTable>()) {
- from = from->getParentOp();
-
- // Check that this is a valid op and isn't an unknown symbol table.
- if (!from || isPotentiallyUnknownSymbolTable(from))
- return nullptr;
- }
- return lookupSymbolIn(from, symbol);
+ Operation *symbolTableOp = getNearestSymbolTable(from);
+ return symbolTableOp ? lookupSymbolIn(symbolTableOp, symbol) : nullptr;
+}
+Operation *SymbolTable::lookupNearestSymbolFrom(Operation *from,
+ SymbolRefAttr symbol) {
+ Operation *symbolTableOp = getNearestSymbolTable(from);
+ return symbolTableOp ? lookupSymbolIn(symbolTableOp, symbol) : nullptr;
}
//===----------------------------------------------------------------------===//
@@ -148,7 +250,7 @@ LogicalResult OpTrait::impl::verifySymbolTable(Operation *op) {
<< "Operations with a 'SymbolTable' must have exactly one block";
// Check that all symbols are uniquely named within child regions.
- llvm::StringMap<Location> nameToOrigLoc;
+ DenseMap<Attribute, Location> nameToOrigLoc;
for (auto &block : op->getRegion(0)) {
for (auto &op : block) {
// Check for a symbol name attribute.
@@ -158,7 +260,7 @@ LogicalResult OpTrait::impl::verifySymbolTable(Operation *op) {
continue;
// Try to insert this symbol into the table.
- auto it = nameToOrigLoc.try_emplace(nameAttr.getValue(), op.getLoc());
+ auto it = nameToOrigLoc.try_emplace(nameAttr, op.getLoc());
if (!it.second)
return op.emitError()
.append("redefinition of symbol named '", nameAttr.getValue(), "'")
@@ -293,6 +395,100 @@ static Optional<WalkResult> walkSymbolUses(
return WalkResult::advance();
}
+/// Walks all of the symbol scopes from 'symbol' to (inclusive) 'limit' invoking
+/// the provided callback at each one with a properly scoped reference to
+/// 'symbol'. The callback takes as parameters the symbol reference at the
+/// current scope as well as the top-level operation representing the top of
+/// that scope.
+static Optional<WalkResult> walkSymbolScopes(
+ Operation *symbol, Operation *limit,
+ function_ref<Optional<WalkResult>(SymbolRefAttr, Operation *)> callback) {
+ StringRef symbolName = SymbolTable::getSymbolName(symbol);
+ assert(!symbol->hasTrait<OpTrait::SymbolTable>() || symbol != limit);
+
+ // Compute the ancestors of 'limit'.
+ llvm::SetVector<Operation *, SmallVector<Operation *, 4>,
+ SmallPtrSet<Operation *, 4>>
+ limitAncestors;
+ Operation *limitAncestor = limit;
+ do {
+ // Check to see if 'symbol' is an ancestor of 'limit'.
+ if (limitAncestor == symbol) {
+ // Check that the nearest symbol table is 'symbol's parent. SymbolRefAttr
+ // doesn't support parent references.
+ if (getNearestSymbolTable(limit) != symbol->getParentOp())
+ return WalkResult::advance();
+ return callback(SymbolRefAttr::get(symbolName, symbol->getContext()),
+ limit);
+ }
+
+ limitAncestors.insert(limitAncestor);
+ } while ((limitAncestor = limitAncestor->getParentOp()));
+
+ // Try to find the first ancestor of 'symbol' that is an ancestor of 'limit'.
+ Operation *commonAncestor = symbol->getParentOp();
+ do {
+ if (limitAncestors.count(commonAncestor))
+ break;
+ } while ((commonAncestor = commonAncestor->getParentOp()));
+ assert(commonAncestor && "'limit' and 'symbol' have no common ancestor");
+
+ // Compute the set of valid nested references for 'symbol' as far up to the
+ // common ancestor as possible.
+ SmallVector<SymbolRefAttr, 2> references;
+ bool collectedAllReferences = succeeded(collectValidReferencesFor(
+ symbol, symbolName, commonAncestor, references));
+
+ // Handle the case where the common ancestor is 'limit'.
+ if (commonAncestor == limit) {
+ // Walk each of the ancestors of 'symbol', calling the compute function for
+ // each one.
+ Operation *limitIt = symbol->getParentOp();
+ for (size_t i = 0, e = references.size(); i != e;
+ ++i, limitIt = limitIt->getParentOp()) {
+ Optional<WalkResult> callbackResult = callback(references[i], limitIt);
+ if (callbackResult != WalkResult::advance())
+ return callbackResult;
+ }
+ return WalkResult::advance();
+ }
+
+ // Otherwise, we just need the symbol reference for 'symbol' that will be
+ // used within 'limit'. This is the last reference in the list we computed
+ // above if we were able to collect all references.
+ if (!collectedAllReferences)
+ return WalkResult::advance();
+ return callback(references.back(), limit);
+}
+
+/// Walk the symbol scopes defined by 'limit' invoking the provided callback.
+static Optional<WalkResult> walkSymbolScopes(
+ StringRef symbol, Operation *limit,
+ function_ref<Optional<WalkResult>(SymbolRefAttr, Operation *)> callback) {
+ return callback(SymbolRefAttr::get(symbol, limit->getContext()), limit);
+}
+
+/// Returns true if the given reference 'SubRef' is a sub reference of the
+/// reference 'ref', i.e. 'ref' is a further qualified reference.
+static bool isReferencePrefixOf(SymbolRefAttr subRef, SymbolRefAttr ref) {
+ if (ref == subRef)
+ return true;
+
+ // If the references are not pointer equal, check to see if `subRef` is a
+ // prefix of `ref`.
+ if (ref.isa<FlatSymbolRefAttr>() ||
+ ref.getRootReference() != subRef.getRootReference())
+ return false;
+
+ auto refLeafs = ref.getNestedReferences();
+ auto subRefLeafs = subRef.getNestedReferences();
+ return subRefLeafs.size() < refLeafs.size() &&
+ subRefLeafs == refLeafs.take_front(subRefLeafs.size());
+}
+
+//===----------------------------------------------------------------------===//
+// SymbolTable::getSymbolUses
+
/// Get an iterator range for all of the uses, for any symbol, that are nested
/// within the given operation 'from'. This does not traverse into any nested
/// symbol tables, and will also only return uses on 'from' if it does not
@@ -302,14 +498,35 @@ static Optional<WalkResult> walkSymbolUses(
/// tables.
auto SymbolTable::getSymbolUses(Operation *from) -> Optional<UseRange> {
std::vector<SymbolUse> uses;
- Optional<WalkResult> result =
- walkSymbolUses(from, [&](SymbolUse symbolUse, ArrayRef<int>) {
- uses.push_back(symbolUse);
- return WalkResult::advance();
- });
+ auto walkFn = [&](SymbolUse symbolUse, ArrayRef<int>) {
+ uses.push_back(symbolUse);
+ return WalkResult::advance();
+ };
+ auto result = walkSymbolUses(from, walkFn);
return result ? Optional<UseRange>(std::move(uses)) : Optional<UseRange>();
}
+//===----------------------------------------------------------------------===//
+// SymbolTable::getSymbolUses
+
+/// The implementation of SymbolTable::getSymbolUses below.
+template <typename SymbolT>
+static Optional<SymbolTable::UseRange> getSymbolUsesImpl(SymbolT symbol,
+ Operation *limit) {
+ std::vector<SymbolTable::SymbolUse> uses;
+ auto walkFn = [&](SymbolRefAttr symbolRefAttr, Operation *from) {
+ return walkSymbolUses(
+ from, [&](SymbolTable::SymbolUse symbolUse, ArrayRef<int>) {
+ if (isReferencePrefixOf(symbolRefAttr, symbolUse.getSymbolRef()))
+ uses.push_back(symbolUse);
+ return WalkResult::advance();
+ });
+ };
+ if (walkSymbolScopes(symbol, limit, walkFn))
+ return SymbolTable::UseRange(std::move(uses));
+ return llvm::None;
+}
+
/// Get all of the uses of the given symbol that are nested within the given
/// operation 'from', invoking the provided callback for each. This does not
/// traverse into any nested symbol tables, and will also only return uses on
@@ -319,16 +536,29 @@ auto SymbolTable::getSymbolUses(Operation *from) -> Optional<UseRange> {
/// potentially be symbol tables.
auto SymbolTable::getSymbolUses(StringRef symbol, Operation *from)
-> Optional<UseRange> {
- SymbolRefAttr symbolRefAttr = SymbolRefAttr::get(symbol, from->getContext());
+ return getSymbolUsesImpl(symbol, from);
+}
+auto SymbolTable::getSymbolUses(Operation *symbol, Operation *from)
+ -> Optional<UseRange> {
+ return getSymbolUsesImpl(symbol, from);
+}
- std::vector<SymbolUse> uses;
- Optional<WalkResult> result =
- walkSymbolUses(from, [&](SymbolUse symbolUse, ArrayRef<int>) {
- if (symbolRefAttr == symbolUse.getSymbolRef())
- uses.push_back(symbolUse);
- return WalkResult::advance();
- });
- return result ? Optional<UseRange>(std::move(uses)) : Optional<UseRange>();
+//===----------------------------------------------------------------------===//
+// SymbolTable::symbolKnownUseEmpty
+
+/// The implementation of SymbolTable::symbolKnownUseEmpty below.
+template <typename SymbolT>
+static bool symbolKnownUseEmptyImpl(SymbolT symbol, Operation *limit) {
+ // Walk all of the symbol uses looking for a reference to 'symbol'.
+ auto walkFn = [&](SymbolRefAttr symbolRefAttr, Operation *from) {
+ return walkSymbolUses(
+ from, [&](SymbolTable::SymbolUse symbolUse, ArrayRef<int>) {
+ return isReferencePrefixOf(symbolRefAttr, symbolUse.getSymbolRef())
+ ? WalkResult::interrupt()
+ : WalkResult::advance();
+ });
+ };
+ return walkSymbolScopes(symbol, limit, walkFn) == WalkResult::advance();
}
/// Return if the given symbol is known to have no uses that are nested within
@@ -338,35 +568,32 @@ auto SymbolTable::getSymbolUses(StringRef symbol, Operation *from)
/// symbol table, and not the op itself. This function will also return false if
/// there are any unknown operations that may potentially be symbol tables.
bool SymbolTable::symbolKnownUseEmpty(StringRef symbol, Operation *from) {
- SymbolRefAttr symbolRefAttr = SymbolRefAttr::get(symbol, from->getContext());
-
- // Walk all of the symbol uses looking for a reference to 'symbol'.
- Optional<WalkResult> walkResult =
- walkSymbolUses(from, [&](SymbolUse symbolUse, ArrayRef<int>) {
- return symbolUse.getSymbolRef() == symbolRefAttr
- ? WalkResult::interrupt()
- : WalkResult::advance();
- });
- return walkResult && !walkResult->wasInterrupted();
+ return symbolKnownUseEmptyImpl(symbol, from);
+}
+bool SymbolTable::symbolKnownUseEmpty(Operation *symbol, Operation *from) {
+ return symbolKnownUseEmptyImpl(symbol, from);
}
+//===----------------------------------------------------------------------===//
+// SymbolTable::replaceAllSymbolUses
+
/// Rebuild the given attribute container after replacing all references to a
-/// symbol with `newSymAttr`.
-static Attribute rebuildAttrAfterRAUW(Attribute container,
- ArrayRef<SmallVector<int, 1>> accesses,
- SymbolRefAttr newSymAttr,
- unsigned depth) {
+/// symbol with the updated attribute in 'accesses'.
+static Attribute rebuildAttrAfterRAUW(
+ Attribute container,
+ ArrayRef<std::pair<SmallVector<int, 1>, SymbolRefAttr>> accesses,
+ unsigned depth) {
// Given a range of Attributes, update the ones referred to by the given
// access chains to point to the new symbol attribute.
auto updateAttrs = [&](auto &&attrRange) {
auto attrBegin = std::begin(attrRange);
for (unsigned i = 0, e = accesses.size(); i != e;) {
- ArrayRef<int> access = accesses[i];
+ ArrayRef<int> access = accesses[i].first;
Attribute &attr = *std::next(attrBegin, access[depth]);
// Check to see if this is a leaf access, i.e. a SymbolRef.
if (access.size() == depth + 1) {
- attr = newSymAttr;
+ attr = accesses[i].second;
++i;
continue;
}
@@ -374,12 +601,12 @@ static Attribute rebuildAttrAfterRAUW(Attribute container,
// Otherwise, this is a container. Collect all of the accesses for this
// index and recurse. The recursion here is bounded by the size of the
// largest access array.
- auto nestedAccesses =
- accesses.drop_front(i).take_while([&](ArrayRef<int> nextAccess) {
- return nextAccess.size() > depth + 1 &&
- nextAccess[depth] == access[depth];
- });
- attr = rebuildAttrAfterRAUW(attr, nestedAccesses, newSymAttr, depth + 1);
+ auto nestedAccesses = accesses.drop_front(i).take_while([&](auto &it) {
+ ArrayRef<int> nextAccess = it.first;
+ return nextAccess.size() > depth + 1 &&
+ nextAccess[depth] == access[depth];
+ });
+ attr = rebuildAttrAfterRAUW(attr, nestedAccesses, depth + 1);
// Skip over all of the accesses that refer to the nested container.
i += nestedAccesses.size();
@@ -396,64 +623,114 @@ static Attribute rebuildAttrAfterRAUW(Attribute container,
return ArrayAttr::get(newAttrs, container.getContext());
}
-/// Attempt to replace all uses of the given symbol 'oldSymbol' with the
-/// provided symbol 'newSymbol' that are nested within the given operation
-/// 'from'. This does not traverse into any nested symbol tables, and will
-/// also only replace uses on 'from' if it does not also define a symbol
-/// table. This is because we treat the region as the boundary of the symbol
-/// table, and not the op itself. If there are any unknown operations that may
-/// potentially be symbol tables, no uses are replaced and failure is returned.
-LogicalResult SymbolTable::replaceAllSymbolUses(StringRef oldSymbol,
- StringRef newSymbol,
- Operation *from) {
- SymbolRefAttr oldAttr = SymbolRefAttr::get(oldSymbol, from->getContext());
- SymbolRefAttr newSymAttr = SymbolRefAttr::get(newSymbol, from->getContext());
+/// Generates a new symbol reference attribute with a new leaf reference.
+SymbolRefAttr generateNewRefAttr(SymbolRefAttr oldAttr,
+ FlatSymbolRefAttr newLeafAttr) {
+ if (oldAttr.isa<FlatSymbolRefAttr>())
+ return newLeafAttr;
+ auto nestedRefs = llvm::to_vector<2>(oldAttr.getNestedReferences());
+ nestedRefs.back() = newLeafAttr;
+ return SymbolRefAttr::get(oldAttr.getRootReference(), nestedRefs,
+ oldAttr.getContext());
+}
+/// The implementation of SymbolTable::replaceAllSymbolUses below.
+template <typename SymbolT>
+static LogicalResult replaceAllSymbolUsesImpl(SymbolT symbol,
+ StringRef newSymbol,
+ Operation *limit) {
// A collection of operations along with their new attribute dictionary.
std::vector<std::pair<Operation *, DictionaryAttr>> updatedAttrDicts;
- // The current operation, and its old symbol access chains, being processed.
+ // The current operation being processed.
Operation *curOp = nullptr;
- SmallVector<SmallVector<int, 1>, 1> accessChains;
+
+ // The set of access chains into the attribute dictionary of the current
+ // operation, as well as the replacement attribute to use.
+ SmallVector<std::pair<SmallVector<int, 1>, SymbolRefAttr>, 1> accessChains;
// Generate a new attribute dictionary for the current operation by replacing
// references to the old symbol.
auto generateNewAttrDict = [&] {
- auto newAttrDict =
- rebuildAttrAfterRAUW(curOp->getAttrList().getDictionary(), accessChains,
- newSymAttr, /*depth=*/0);
- return newAttrDict.cast<DictionaryAttr>();
+ auto oldDict = curOp->getAttrList().getDictionary();
+ auto newDict = rebuildAttrAfterRAUW(oldDict, accessChains, /*depth=*/0);
+ return newDict.cast<DictionaryAttr>();
};
- // Walk the symbol uses collecting uses of the old symbol.
- auto walkFn = [&](SymbolTable::SymbolUse symbolUse,
- ArrayRef<int> accessChain) {
- if (symbolUse.getSymbolRef() != oldAttr)
+ // Generate a new attribute to replace the given attribute.
+ MLIRContext *ctx = limit->getContext();
+ FlatSymbolRefAttr newLeafAttr = FlatSymbolRefAttr::get(newSymbol, ctx);
+ auto scopeWalkFn = [&](SymbolRefAttr oldAttr,
+ Operation *from) -> Optional<WalkResult> {
+ SymbolRefAttr newAttr = generateNewRefAttr(oldAttr, newLeafAttr);
+ auto walkFn = [&](SymbolTable::SymbolUse symbolUse,
+ ArrayRef<int> accessChain) {
+ SymbolRefAttr useRef = symbolUse.getSymbolRef();
+ if (!isReferencePrefixOf(oldAttr, useRef))
+ return WalkResult::advance();
+
+ // If we have a valid match, check to see if this is a proper
+ // subreference. If it is, then we will need to generate a different new
+ // attribute specifically for this use.
+ SymbolRefAttr replacementRef = newAttr;
+ if (useRef != oldAttr) {
+ if (oldAttr.isa<FlatSymbolRefAttr>()) {
+ replacementRef =
+ SymbolRefAttr::get(newSymbol, useRef.getNestedReferences(), ctx);
+ } else {
+ auto nestedRefs = llvm::to_vector<4>(useRef.getNestedReferences());
+ nestedRefs[oldAttr.getNestedReferences().size() - 1] = newLeafAttr;
+ replacementRef =
+ SymbolRefAttr::get(useRef.getRootReference(), nestedRefs, ctx);
+ }
+ }
+
+ // If there was a previous operation, generate a new attribute dict
+ // for it. This means that we've finished processing the current
+ // operation, so generate a new dictionary for it.
+ if (curOp && symbolUse.getUser() != curOp) {
+ updatedAttrDicts.push_back({curOp, generateNewAttrDict()});
+ accessChains.clear();
+ }
+
+ // Record this access.
+ curOp = symbolUse.getUser();
+ accessChains.push_back({llvm::to_vector<1>(accessChain), replacementRef});
return WalkResult::advance();
+ };
+ if (!walkSymbolUses(from, walkFn))
+ return llvm::None;
- // If there was a previous operation, generate a new attribute dict for it.
- // This means that we've finished processing the current operation, so
- // generate a new dictionary for it.
- if (curOp && symbolUse.getUser() != curOp) {
+ // Check to see if we have a dangling op that needs to be processed.
+ if (curOp) {
updatedAttrDicts.push_back({curOp, generateNewAttrDict()});
- accessChains.clear();
+ curOp = nullptr;
}
-
- // Record this access.
- curOp = symbolUse.getUser();
- accessChains.push_back(llvm::to_vector<1>(accessChain));
return WalkResult::advance();
};
- if (!walkSymbolUses(from, walkFn))
+ if (!walkSymbolScopes(symbol, limit, scopeWalkFn))
return failure();
// Update the attribute dictionaries as necessary.
for (auto &it : updatedAttrDicts)
it.first->setAttrs(it.second);
-
- // Check to see if we have a dangling op that needs to be processed.
- if (curOp)
- curOp->setAttrs(generateNewAttrDict());
-
return success();
}
+
+/// Attempt to replace all uses of the given symbol 'oldSymbol' with the
+/// provided symbol 'newSymbol' that are nested within the given operation
+/// 'from'. This does not traverse into any nested symbol tables, and will
+/// also only replace uses on 'from' if it does not also define a symbol
+/// table. This is because we treat the region as the boundary of the symbol
+/// table, and not the op itself. If there are any unknown operations that may
+/// potentially be symbol tables, no uses are replaced and failure is returned.
+LogicalResult SymbolTable::replaceAllSymbolUses(StringRef oldSymbol,
+ StringRef newSymbol,
+ Operation *from) {
+ return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from);
+}
+LogicalResult SymbolTable::replaceAllSymbolUses(Operation *oldSymbol,
+ StringRef newSymbol,
+ Operation *from) {
+ return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from);
+}
diff --git a/mlir/test/IR/test-symbol-rauw.mlir b/mlir/test/IR/test-symbol-rauw.mlir
index 963875e762e7..25d22e432256 100644
--- a/mlir/test/IR/test-symbol-rauw.mlir
+++ b/mlir/test/IR/test-symbol-rauw.mlir
@@ -32,6 +32,39 @@ module attributes {sym.outside_use = @symbol_foo } {
// -----
+// Check the support for nested references.
+
+// CHECK: module
+module {
+ // CHECK: module @module_a
+ module @module_a {
+ // CHECK: func @replaced_foo
+ func @foo() attributes {sym.new_name = "replaced_foo" }
+ }
+
+ // CHECK: module @replaced_module_b
+ module @module_b attributes {sym.new_name = "replaced_module_b"} {
+ // CHECK: module @replaced_module_c
+ module @module_c attributes {sym.new_name = "replaced_module_c"} {
+ // CHECK: func @replaced_foo
+ func @foo() attributes {sym.new_name = "replaced_foo" }
+ }
+ }
+
+ // CHECK: func @symbol_bar
+ func @symbol_bar() {
+ // CHECK: foo.op
+ // CHECK-SAME: use_1 = @module_a::@replaced_foo
+ // CHECK-SAME: use_2 = @replaced_module_b::@replaced_module_c::@replaced_foo
+ "foo.op"() {
+ use_1 = @module_a::@foo,
+ use_2 = @module_b::@module_c::@foo
+ } : () -> ()
+ }
+}
+
+// -----
+
// Check that the replacement fails for potentially unknown symbol tables.
module {
// CHECK: func @failed_repl
diff --git a/mlir/test/IR/test-symbol-uses.mlir b/mlir/test/IR/test-symbol-uses.mlir
index 4d6555aceedf..f95ac6b2d701 100644
--- a/mlir/test/IR/test-symbol-uses.mlir
+++ b/mlir/test/IR/test-symbol-uses.mlir
@@ -4,14 +4,14 @@
// its table.
// expected-remark@below {{symbol_removable function successfully erased}}
module attributes {sym.outside_use = @symbol_foo } {
- // expected-remark@+1 {{function has 2 uses}}
+ // expected-remark@+1 {{symbol has 2 uses}}
func @symbol_foo()
- // expected-remark@below {{function has no uses}}
- // expected-remark@below {{found use of function : @symbol_foo}}
- // expected-remark@below {{function contains 2 nested references}}
+ // expected-remark@below {{symbol has no uses}}
+ // expected-remark@below {{found use of symbol : @symbol_foo}}
+ // expected-remark@below {{symbol contains 2 nested references}}
func @symbol_bar() attributes {sym.use = @symbol_foo} {
- // expected-remark@+1 {{found use of function : @symbol_foo}}
+ // expected-remark@+1 {{found use of symbol : @symbol_foo}}
"foo.op"() {
non_symbol_attr,
use = [{ nested_symbol = [@symbol_foo]}],
@@ -19,13 +19,13 @@ module attributes {sym.outside_use = @symbol_foo } {
} : () -> ()
}
- // expected-remark@below {{function has no uses}}
+ // expected-remark@below {{symbol has no uses}}
func @symbol_removable()
- // expected-remark@+1 {{function has 1 use}}
+ // expected-remark@+1 {{symbol has 1 use}}
func @symbol_baz()
- // expected-remark@+1 {{found use of function : @symbol_baz}}
+ // expected-remark@+1 {{found use of symbol : @symbol_baz}}
module attributes {test.reference = @symbol_baz} {
"foo.op"() {test.nested_reference = @symbol_baz} : () -> ()
}
@@ -33,6 +33,34 @@ module attributes {sym.outside_use = @symbol_foo } {
// -----
+// Test nested attribute support
+module {
+ // expected-remark@+1 {{symbol has 2 uses}}
+ module @module_b {
+ // expected-remark@+1 {{symbol has 1 uses}}
+ module @module_c {
+ // expected-remark@+1 {{symbol has 1 uses}}
+ func @foo()
+ }
+ }
+
+ // expected-remark@below {{symbol has no uses}}
+ // expected-remark@below {{symbol contains 2 nested references}}
+ func @symbol_bar() {
+ // expected-remark@below {{found use of symbol : @module_b::@module_c::@foo : "foo"}}
+ // expected-remark@below {{found use of symbol : @module_b::@module_c::@foo : "module_c"}}
+ // expected-remark@below {{found use of symbol : @module_b::@module_c::@foo : "module_b"}}
+ // expected-remark@below {{found use of symbol : @module_b : "module_b"}}
+ "foo.op"() {
+ use_1 = [{ nested_symbol = [@module_b::@module_c::@foo]}],
+ use_2 = @module_b
+ } : () -> ()
+ }
+}
+
+
+// -----
+
// expected-remark@+1 {{contains an unknown nested operation that 'may' define a new symbol table}}
func @symbol_bar() {
"foo.possibly_unknown_symbol_table"() ({
diff --git a/mlir/test/lib/IR/TestSymbolUses.cpp b/mlir/test/lib/IR/TestSymbolUses.cpp
index c8fb1d8eecfc..dc0608b21c51 100644
--- a/mlir/test/lib/IR/TestSymbolUses.cpp
+++ b/mlir/test/lib/IR/TestSymbolUses.cpp
@@ -16,54 +16,70 @@ namespace {
/// This is a symbol test pass that tests the symbol uselist functionality
/// provided by the symbol table along with erasing from the symbol table.
struct SymbolUsesPass : public ModulePass<SymbolUsesPass> {
- void runOnModule() override {
- auto module = getModule();
- std::vector<FuncOp> ops_to_delete;
+ WalkResult operateOnSymbol(Operation *symbol, Operation *module,
+ SmallVectorImpl<FuncOp> &deadFunctions) {
+ // Test computing uses on a non symboltable op.
+ Optional<SymbolTable::UseRange> symbolUses =
+ SymbolTable::getSymbolUses(symbol);
- for (FuncOp func : module.getOps<FuncOp>()) {
- // Test computing uses on a non symboltable op.
- Optional<SymbolTable::UseRange> symbolUses =
- SymbolTable::getSymbolUses(func);
+ // Test the conservative failure case.
+ if (!symbolUses) {
+ symbol->emitRemark()
+ << "symbol contains an unknown nested operation that "
+ "'may' define a new symbol table";
+ return WalkResult::interrupt();
+ }
+ if (unsigned numUses = llvm::size(*symbolUses))
+ symbol->emitRemark() << "symbol contains " << numUses
+ << " nested references";
- // Test the conservative failure case.
- if (!symbolUses) {
- func.emitRemark() << "function contains an unknown nested operation "
- "that 'may' define a new symbol table";
- return;
- }
- if (unsigned numUses = llvm::size(*symbolUses))
- func.emitRemark() << "function contains " << numUses
- << " nested references";
+ // Test the functionality of symbolKnownUseEmpty.
+ if (SymbolTable::symbolKnownUseEmpty(symbol, module)) {
+ FuncOp funcSymbol = dyn_cast<FuncOp>(symbol);
+ if (funcSymbol && funcSymbol.isExternal())
+ deadFunctions.push_back(funcSymbol);
- // Test the functionality of symbolKnownUseEmpty.
- if (func.symbolKnownUseEmpty(module)) {
- func.emitRemark() << "function has no uses";
- if (func.getBody().empty())
- ops_to_delete.push_back(func);
- continue;
- }
+ symbol->emitRemark() << "symbol has no uses";
+ return WalkResult::advance();
+ }
- // Test the functionality of getSymbolUses.
- symbolUses = func.getSymbolUses(module);
- assert(symbolUses.hasValue() && "expected no unknown operations");
- for (SymbolTable::SymbolUse symbolUse : *symbolUses) {
+ // Test the functionality of getSymbolUses.
+ symbolUses = SymbolTable::getSymbolUses(symbol, module);
+ assert(symbolUses.hasValue() && "expected no unknown operations");
+ for (SymbolTable::SymbolUse symbolUse : *symbolUses) {
+ // Check that we can resolve back to our symbol.
+ if (Operation *op = SymbolTable::lookupNearestSymbolFrom(
+ symbolUse.getUser()->getParentOp(), symbolUse.getSymbolRef())) {
symbolUse.getUser()->emitRemark()
- << "found use of function : " << symbolUse.getSymbolRef();
+ << "found use of symbol : " << symbolUse.getSymbolRef() << " : "
+ << symbol->getAttr(SymbolTable::getSymbolAttrName());
}
- func.emitRemark() << "function has " << llvm::size(*symbolUses)
- << " uses";
}
+ symbol->emitRemark() << "symbol has " << llvm::size(*symbolUses) << " uses";
+ return WalkResult::advance();
+ }
+
+ void runOnModule() override {
+ auto module = getModule();
- for (FuncOp func : ops_to_delete) {
+ // Walk nested symbols.
+ SmallVector<FuncOp, 4> deadFunctions;
+ module.getBodyRegion().walk([&](Operation *nestedOp) {
+ if (SymbolTable::isSymbol(nestedOp))
+ return operateOnSymbol(nestedOp, module, deadFunctions);
+ return WalkResult::advance();
+ });
+
+ for (Operation *op : deadFunctions) {
// In order to test the SymbolTable::erase method, also erase completely
// useless functions.
SymbolTable table(module);
- auto func_name = func.getName();
- assert(table.lookup(func_name) && "expected no unknown operations");
- table.erase(func);
- assert(!table.lookup(func_name) &&
+ auto name = SymbolTable::getSymbolName(op);
+ assert(table.lookup(name) && "expected no unknown operations");
+ table.erase(op);
+ assert(!table.lookup(name) &&
"expected erased operation to be unknown now");
- module.emitRemark() << func_name << " function successfully erased";
+ module.emitRemark() << name << " function successfully erased";
}
}
};
@@ -74,13 +90,15 @@ struct SymbolReplacementPass : public ModulePass<SymbolReplacementPass> {
void runOnModule() override {
auto module = getModule();
- for (FuncOp func : module.getOps<FuncOp>()) {
- StringAttr newName = func.getAttrOfType<StringAttr>("sym.new_name");
+ // Walk nested functions and modules.
+ module.getBodyRegion().walk([&](Operation *nestedOp) {
+ StringAttr newName = nestedOp->getAttrOfType<StringAttr>("sym.new_name");
if (!newName)
- continue;
- if (succeeded(func.replaceAllSymbolUses(newName.getValue(), module)))
- func.setName(newName.getValue());
- }
+ return;
+ if (succeeded(SymbolTable::replaceAllSymbolUses(
+ nestedOp, newName.getValue(), module)))
+ SymbolTable::setSymbolName(nestedOp, newName.getValue());
+ });
}
};
} // end anonymous namespace