diff options
Diffstat (limited to 'mlir/lib/IR/SymbolTable.cpp')
-rw-r--r-- | mlir/lib/IR/SymbolTable.cpp | 483 |
1 files changed, 380 insertions, 103 deletions
diff --git a/mlir/lib/IR/SymbolTable.cpp b/mlir/lib/IR/SymbolTable.cpp index 83e5802093c9..057aeded242a 100644 --- a/mlir/lib/IR/SymbolTable.cpp +++ b/mlir/lib/IR/SymbolTable.cpp @@ -7,6 +7,8 @@ //===----------------------------------------------------------------------===// #include "mlir/IR/SymbolTable.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallString.h" using namespace mlir; @@ -17,6 +19,71 @@ static bool isPotentiallyUnknownSymbolTable(Operation *op) { return !op->getDialect() && op->getNumRegions() == 1; } +/// Returns the nearest symbol table from a given operation `from`. Returns +/// nullptr if no valid parent symbol table could be found. +static Operation *getNearestSymbolTable(Operation *from) { + assert(from && "expected valid operation"); + if (isPotentiallyUnknownSymbolTable(from)) + return nullptr; + + while (!from->hasTrait<OpTrait::SymbolTable>()) { + from = from->getParentOp(); + + // Check that this is a valid op and isn't an unknown symbol table. + if (!from || isPotentiallyUnknownSymbolTable(from)) + return nullptr; + } + return from; +} + +/// Returns the string name of the given symbol, or None if this is not a +/// symbol. +static Optional<StringRef> getNameIfSymbol(Operation *symbol) { + auto nameAttr = + symbol->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName()); + return nameAttr ? nameAttr.getValue() : Optional<StringRef>(); +} + +/// Computes the nested symbol reference attribute for the symbol 'symbolName' +/// that are usable within the symbol table operations from 'symbol' as far up +/// to the given operation 'within', where 'within' is an ancestor of 'symbol'. +/// Returns success if all references up to 'within' could be computed. +static LogicalResult +collectValidReferencesFor(Operation *symbol, StringRef symbolName, + Operation *within, + SmallVectorImpl<SymbolRefAttr> &results) { + assert(within->isAncestor(symbol) && "expected 'within' to be an ancestor"); + MLIRContext *ctx = symbol->getContext(); + + auto leafRef = FlatSymbolRefAttr::get(symbolName, ctx); + results.push_back(leafRef); + + // Early exit for when 'within' is the parent of 'symbol'. + Operation *symbolTableOp = symbol->getParentOp(); + if (within == symbolTableOp) + return success(); + + // Collect references until 'symbolTableOp' reaches 'within'. + SmallVector<FlatSymbolRefAttr, 1> nestedRefs(1, leafRef); + do { + // Each parent of 'symbol' should define a symbol table. + if (!symbolTableOp->hasTrait<OpTrait::SymbolTable>()) + return failure(); + // Each parent of 'symbol' should also be a symbol. + Optional<StringRef> symbolTableName = getNameIfSymbol(symbolTableOp); + if (!symbolTableName) + return failure(); + results.push_back(SymbolRefAttr::get(*symbolTableName, nestedRefs, ctx)); + + symbolTableOp = symbolTableOp->getParentOp(); + if (symbolTableOp == within) + break; + nestedRefs.insert(nestedRefs.begin(), + FlatSymbolRefAttr::get(*symbolTableName, ctx)); + } while (true); + return success(); +} + //===----------------------------------------------------------------------===// // SymbolTable //===----------------------------------------------------------------------===// @@ -32,11 +99,11 @@ SymbolTable::SymbolTable(Operation *symbolTableOp) "expected operation to have a single block"); for (auto &op : symbolTableOp->getRegion(0).front()) { - auto nameAttr = op.getAttrOfType<StringAttr>(getSymbolAttrName()); - if (!nameAttr) + Optional<StringRef> name = getNameIfSymbol(&op); + if (!name) continue; - auto inserted = symbolTable.insert({nameAttr.getValue(), &op}); + auto inserted = symbolTable.insert({*name, &op}); (void)inserted; assert(inserted.second && "expected region to contain uniquely named symbol operations"); @@ -51,13 +118,13 @@ Operation *SymbolTable::lookup(StringRef name) const { /// Erase the given symbol from the table. void SymbolTable::erase(Operation *symbol) { - auto nameAttr = symbol->getAttrOfType<StringAttr>(getSymbolAttrName()); - assert(nameAttr && "expected valid 'name' attribute"); + Optional<StringRef> name = getNameIfSymbol(symbol); + assert(name && "expected valid 'name' attribute"); assert(symbol->getParentOp() == symbolTableOp && "expected this operation to be inside of the operation with this " "SymbolTable"); - auto it = symbolTable.find(nameAttr.getValue()); + auto it = symbolTable.find(*name); if (it != symbolTable.end() && it->second == symbol) { symbolTable.erase(it); symbol->erase(); @@ -67,9 +134,6 @@ void SymbolTable::erase(Operation *symbol) { /// Insert a new symbol into the table and associated operation, and rename it /// as necessary to avoid collisions. void SymbolTable::insert(Operation *symbol, Block::iterator insertPt) { - auto nameAttr = symbol->getAttrOfType<StringAttr>(getSymbolAttrName()); - assert(nameAttr && "expected valid 'name' attribute"); - auto &body = symbolTableOp->getRegion(0).front(); if (insertPt == Block::iterator() || insertPt == body.end()) insertPt = Block::iterator(body.getTerminator()); @@ -81,12 +145,12 @@ void SymbolTable::insert(Operation *symbol, Block::iterator insertPt) { // Add this symbol to the symbol table, uniquing the name if a conflict is // detected. - if (symbolTable.insert({nameAttr.getValue(), symbol}).second) + StringRef name = getSymbolName(symbol); + if (symbolTable.insert({name, symbol}).second) return; - // If a conflict was detected, then the symbol will not have been added to // the symbol table. Try suffixes until we get to a unique name that works. - SmallString<128> nameBuffer(nameAttr.getValue()); + SmallString<128> nameBuffer(name); unsigned originalLength = nameBuffer.size(); // Iteratively try suffixes until we find one that isn't used. @@ -95,8 +159,24 @@ void SymbolTable::insert(Operation *symbol, Block::iterator insertPt) { nameBuffer += '_'; nameBuffer += std::to_string(uniquingCounter++); } while (!symbolTable.insert({nameBuffer, symbol}).second); + setSymbolName(symbol, nameBuffer); +} + +/// Returns true if the given operation defines a symbol. +bool SymbolTable::isSymbol(Operation *op) { + return op->hasTrait<OpTrait::Symbol>() || getNameIfSymbol(op).hasValue(); +} + +/// Returns the name of the given symbol operation. +StringRef SymbolTable::getSymbolName(Operation *symbol) { + Optional<StringRef> name = getNameIfSymbol(symbol); + assert(name && "expected valid symbol name"); + return *name; +} +/// Sets the name of the given symbol operation. +void SymbolTable::setSymbolName(Operation *symbol, StringRef name) { symbol->setAttr(getSymbolAttrName(), - StringAttr::get(nameBuffer, symbolTableOp->getContext())); + StringAttr::get(name, symbol->getContext())); } /// Returns the operation registered with the given symbol name with the @@ -109,30 +189,52 @@ Operation *SymbolTable::lookupSymbolIn(Operation *symbolTableOp, // Look for a symbol with the given name. for (auto &block : symbolTableOp->getRegion(0)) { - for (auto &op : block) { - auto nameAttr = op.template getAttrOfType<StringAttr>( - mlir::SymbolTable::getSymbolAttrName()); - if (nameAttr && nameAttr.getValue() == symbol) + for (auto &op : block) + if (getNameIfSymbol(&op) == symbol) return &op; - } } return nullptr; } +Operation *SymbolTable::lookupSymbolIn(Operation *symbolTableOp, + SymbolRefAttr symbol) { + assert(symbolTableOp->hasTrait<OpTrait::SymbolTable>()); + + // Lookup the root reference for this symbol. + symbolTableOp = lookupSymbolIn(symbolTableOp, symbol.getRootReference()); + if (!symbolTableOp) + return nullptr; + + // If there are no nested references, just return the root symbol directly. + ArrayRef<FlatSymbolRefAttr> nestedRefs = symbol.getNestedReferences(); + if (nestedRefs.empty()) + return symbolTableOp; + + // Verify that the root is also a symbol table. + if (!symbolTableOp->hasTrait<OpTrait::SymbolTable>()) + return nullptr; + + // Otherwise, lookup each of the nested non-leaf references and ensure that + // each corresponds to a valid symbol table. + for (FlatSymbolRefAttr ref : nestedRefs.drop_back()) { + symbolTableOp = lookupSymbolIn(symbolTableOp, ref.getValue()); + if (!symbolTableOp || !symbolTableOp->hasTrait<OpTrait::SymbolTable>()) + return nullptr; + } + return lookupSymbolIn(symbolTableOp, symbol.getLeafReference()); +} /// Returns the operation registered with the given symbol name within the /// closes parent operation with the 'OpTrait::SymbolTable' trait. Returns /// nullptr if no valid symbol was found. Operation *SymbolTable::lookupNearestSymbolFrom(Operation *from, StringRef symbol) { - assert(from && "expected valid operation"); - while (!from->hasTrait<OpTrait::SymbolTable>()) { - from = from->getParentOp(); - - // Check that this is a valid op and isn't an unknown symbol table. - if (!from || isPotentiallyUnknownSymbolTable(from)) - return nullptr; - } - return lookupSymbolIn(from, symbol); + Operation *symbolTableOp = getNearestSymbolTable(from); + return symbolTableOp ? lookupSymbolIn(symbolTableOp, symbol) : nullptr; +} +Operation *SymbolTable::lookupNearestSymbolFrom(Operation *from, + SymbolRefAttr symbol) { + Operation *symbolTableOp = getNearestSymbolTable(from); + return symbolTableOp ? lookupSymbolIn(symbolTableOp, symbol) : nullptr; } //===----------------------------------------------------------------------===// @@ -148,7 +250,7 @@ LogicalResult OpTrait::impl::verifySymbolTable(Operation *op) { << "Operations with a 'SymbolTable' must have exactly one block"; // Check that all symbols are uniquely named within child regions. - llvm::StringMap<Location> nameToOrigLoc; + DenseMap<Attribute, Location> nameToOrigLoc; for (auto &block : op->getRegion(0)) { for (auto &op : block) { // Check for a symbol name attribute. @@ -158,7 +260,7 @@ LogicalResult OpTrait::impl::verifySymbolTable(Operation *op) { continue; // Try to insert this symbol into the table. - auto it = nameToOrigLoc.try_emplace(nameAttr.getValue(), op.getLoc()); + auto it = nameToOrigLoc.try_emplace(nameAttr, op.getLoc()); if (!it.second) return op.emitError() .append("redefinition of symbol named '", nameAttr.getValue(), "'") @@ -293,6 +395,100 @@ static Optional<WalkResult> walkSymbolUses( return WalkResult::advance(); } +/// Walks all of the symbol scopes from 'symbol' to (inclusive) 'limit' invoking +/// the provided callback at each one with a properly scoped reference to +/// 'symbol'. The callback takes as parameters the symbol reference at the +/// current scope as well as the top-level operation representing the top of +/// that scope. +static Optional<WalkResult> walkSymbolScopes( + Operation *symbol, Operation *limit, + function_ref<Optional<WalkResult>(SymbolRefAttr, Operation *)> callback) { + StringRef symbolName = SymbolTable::getSymbolName(symbol); + assert(!symbol->hasTrait<OpTrait::SymbolTable>() || symbol != limit); + + // Compute the ancestors of 'limit'. + llvm::SetVector<Operation *, SmallVector<Operation *, 4>, + SmallPtrSet<Operation *, 4>> + limitAncestors; + Operation *limitAncestor = limit; + do { + // Check to see if 'symbol' is an ancestor of 'limit'. + if (limitAncestor == symbol) { + // Check that the nearest symbol table is 'symbol's parent. SymbolRefAttr + // doesn't support parent references. + if (getNearestSymbolTable(limit) != symbol->getParentOp()) + return WalkResult::advance(); + return callback(SymbolRefAttr::get(symbolName, symbol->getContext()), + limit); + } + + limitAncestors.insert(limitAncestor); + } while ((limitAncestor = limitAncestor->getParentOp())); + + // Try to find the first ancestor of 'symbol' that is an ancestor of 'limit'. + Operation *commonAncestor = symbol->getParentOp(); + do { + if (limitAncestors.count(commonAncestor)) + break; + } while ((commonAncestor = commonAncestor->getParentOp())); + assert(commonAncestor && "'limit' and 'symbol' have no common ancestor"); + + // Compute the set of valid nested references for 'symbol' as far up to the + // common ancestor as possible. + SmallVector<SymbolRefAttr, 2> references; + bool collectedAllReferences = succeeded(collectValidReferencesFor( + symbol, symbolName, commonAncestor, references)); + + // Handle the case where the common ancestor is 'limit'. + if (commonAncestor == limit) { + // Walk each of the ancestors of 'symbol', calling the compute function for + // each one. + Operation *limitIt = symbol->getParentOp(); + for (size_t i = 0, e = references.size(); i != e; + ++i, limitIt = limitIt->getParentOp()) { + Optional<WalkResult> callbackResult = callback(references[i], limitIt); + if (callbackResult != WalkResult::advance()) + return callbackResult; + } + return WalkResult::advance(); + } + + // Otherwise, we just need the symbol reference for 'symbol' that will be + // used within 'limit'. This is the last reference in the list we computed + // above if we were able to collect all references. + if (!collectedAllReferences) + return WalkResult::advance(); + return callback(references.back(), limit); +} + +/// Walk the symbol scopes defined by 'limit' invoking the provided callback. +static Optional<WalkResult> walkSymbolScopes( + StringRef symbol, Operation *limit, + function_ref<Optional<WalkResult>(SymbolRefAttr, Operation *)> callback) { + return callback(SymbolRefAttr::get(symbol, limit->getContext()), limit); +} + +/// Returns true if the given reference 'SubRef' is a sub reference of the +/// reference 'ref', i.e. 'ref' is a further qualified reference. +static bool isReferencePrefixOf(SymbolRefAttr subRef, SymbolRefAttr ref) { + if (ref == subRef) + return true; + + // If the references are not pointer equal, check to see if `subRef` is a + // prefix of `ref`. + if (ref.isa<FlatSymbolRefAttr>() || + ref.getRootReference() != subRef.getRootReference()) + return false; + + auto refLeafs = ref.getNestedReferences(); + auto subRefLeafs = subRef.getNestedReferences(); + return subRefLeafs.size() < refLeafs.size() && + subRefLeafs == refLeafs.take_front(subRefLeafs.size()); +} + +//===----------------------------------------------------------------------===// +// SymbolTable::getSymbolUses + /// Get an iterator range for all of the uses, for any symbol, that are nested /// within the given operation 'from'. This does not traverse into any nested /// symbol tables, and will also only return uses on 'from' if it does not @@ -302,14 +498,35 @@ static Optional<WalkResult> walkSymbolUses( /// tables. auto SymbolTable::getSymbolUses(Operation *from) -> Optional<UseRange> { std::vector<SymbolUse> uses; - Optional<WalkResult> result = - walkSymbolUses(from, [&](SymbolUse symbolUse, ArrayRef<int>) { - uses.push_back(symbolUse); - return WalkResult::advance(); - }); + auto walkFn = [&](SymbolUse symbolUse, ArrayRef<int>) { + uses.push_back(symbolUse); + return WalkResult::advance(); + }; + auto result = walkSymbolUses(from, walkFn); return result ? Optional<UseRange>(std::move(uses)) : Optional<UseRange>(); } +//===----------------------------------------------------------------------===// +// SymbolTable::getSymbolUses + +/// The implementation of SymbolTable::getSymbolUses below. +template <typename SymbolT> +static Optional<SymbolTable::UseRange> getSymbolUsesImpl(SymbolT symbol, + Operation *limit) { + std::vector<SymbolTable::SymbolUse> uses; + auto walkFn = [&](SymbolRefAttr symbolRefAttr, Operation *from) { + return walkSymbolUses( + from, [&](SymbolTable::SymbolUse symbolUse, ArrayRef<int>) { + if (isReferencePrefixOf(symbolRefAttr, symbolUse.getSymbolRef())) + uses.push_back(symbolUse); + return WalkResult::advance(); + }); + }; + if (walkSymbolScopes(symbol, limit, walkFn)) + return SymbolTable::UseRange(std::move(uses)); + return llvm::None; +} + /// Get all of the uses of the given symbol that are nested within the given /// operation 'from', invoking the provided callback for each. This does not /// traverse into any nested symbol tables, and will also only return uses on @@ -319,16 +536,29 @@ auto SymbolTable::getSymbolUses(Operation *from) -> Optional<UseRange> { /// potentially be symbol tables. auto SymbolTable::getSymbolUses(StringRef symbol, Operation *from) -> Optional<UseRange> { - SymbolRefAttr symbolRefAttr = SymbolRefAttr::get(symbol, from->getContext()); + return getSymbolUsesImpl(symbol, from); +} +auto SymbolTable::getSymbolUses(Operation *symbol, Operation *from) + -> Optional<UseRange> { + return getSymbolUsesImpl(symbol, from); +} - std::vector<SymbolUse> uses; - Optional<WalkResult> result = - walkSymbolUses(from, [&](SymbolUse symbolUse, ArrayRef<int>) { - if (symbolRefAttr == symbolUse.getSymbolRef()) - uses.push_back(symbolUse); - return WalkResult::advance(); - }); - return result ? Optional<UseRange>(std::move(uses)) : Optional<UseRange>(); +//===----------------------------------------------------------------------===// +// SymbolTable::symbolKnownUseEmpty + +/// The implementation of SymbolTable::symbolKnownUseEmpty below. +template <typename SymbolT> +static bool symbolKnownUseEmptyImpl(SymbolT symbol, Operation *limit) { + // Walk all of the symbol uses looking for a reference to 'symbol'. + auto walkFn = [&](SymbolRefAttr symbolRefAttr, Operation *from) { + return walkSymbolUses( + from, [&](SymbolTable::SymbolUse symbolUse, ArrayRef<int>) { + return isReferencePrefixOf(symbolRefAttr, symbolUse.getSymbolRef()) + ? WalkResult::interrupt() + : WalkResult::advance(); + }); + }; + return walkSymbolScopes(symbol, limit, walkFn) == WalkResult::advance(); } /// Return if the given symbol is known to have no uses that are nested within @@ -338,35 +568,32 @@ auto SymbolTable::getSymbolUses(StringRef symbol, Operation *from) /// symbol table, and not the op itself. This function will also return false if /// there are any unknown operations that may potentially be symbol tables. bool SymbolTable::symbolKnownUseEmpty(StringRef symbol, Operation *from) { - SymbolRefAttr symbolRefAttr = SymbolRefAttr::get(symbol, from->getContext()); - - // Walk all of the symbol uses looking for a reference to 'symbol'. - Optional<WalkResult> walkResult = - walkSymbolUses(from, [&](SymbolUse symbolUse, ArrayRef<int>) { - return symbolUse.getSymbolRef() == symbolRefAttr - ? WalkResult::interrupt() - : WalkResult::advance(); - }); - return walkResult && !walkResult->wasInterrupted(); + return symbolKnownUseEmptyImpl(symbol, from); +} +bool SymbolTable::symbolKnownUseEmpty(Operation *symbol, Operation *from) { + return symbolKnownUseEmptyImpl(symbol, from); } +//===----------------------------------------------------------------------===// +// SymbolTable::replaceAllSymbolUses + /// Rebuild the given attribute container after replacing all references to a -/// symbol with `newSymAttr`. -static Attribute rebuildAttrAfterRAUW(Attribute container, - ArrayRef<SmallVector<int, 1>> accesses, - SymbolRefAttr newSymAttr, - unsigned depth) { +/// symbol with the updated attribute in 'accesses'. +static Attribute rebuildAttrAfterRAUW( + Attribute container, + ArrayRef<std::pair<SmallVector<int, 1>, SymbolRefAttr>> accesses, + unsigned depth) { // Given a range of Attributes, update the ones referred to by the given // access chains to point to the new symbol attribute. auto updateAttrs = [&](auto &&attrRange) { auto attrBegin = std::begin(attrRange); for (unsigned i = 0, e = accesses.size(); i != e;) { - ArrayRef<int> access = accesses[i]; + ArrayRef<int> access = accesses[i].first; Attribute &attr = *std::next(attrBegin, access[depth]); // Check to see if this is a leaf access, i.e. a SymbolRef. if (access.size() == depth + 1) { - attr = newSymAttr; + attr = accesses[i].second; ++i; continue; } @@ -374,12 +601,12 @@ static Attribute rebuildAttrAfterRAUW(Attribute container, // Otherwise, this is a container. Collect all of the accesses for this // index and recurse. The recursion here is bounded by the size of the // largest access array. - auto nestedAccesses = - accesses.drop_front(i).take_while([&](ArrayRef<int> nextAccess) { - return nextAccess.size() > depth + 1 && - nextAccess[depth] == access[depth]; - }); - attr = rebuildAttrAfterRAUW(attr, nestedAccesses, newSymAttr, depth + 1); + auto nestedAccesses = accesses.drop_front(i).take_while([&](auto &it) { + ArrayRef<int> nextAccess = it.first; + return nextAccess.size() > depth + 1 && + nextAccess[depth] == access[depth]; + }); + attr = rebuildAttrAfterRAUW(attr, nestedAccesses, depth + 1); // Skip over all of the accesses that refer to the nested container. i += nestedAccesses.size(); @@ -396,64 +623,114 @@ static Attribute rebuildAttrAfterRAUW(Attribute container, return ArrayAttr::get(newAttrs, container.getContext()); } -/// Attempt to replace all uses of the given symbol 'oldSymbol' with the -/// provided symbol 'newSymbol' that are nested within the given operation -/// 'from'. This does not traverse into any nested symbol tables, and will -/// also only replace uses on 'from' if it does not also define a symbol -/// table. This is because we treat the region as the boundary of the symbol -/// table, and not the op itself. If there are any unknown operations that may -/// potentially be symbol tables, no uses are replaced and failure is returned. -LogicalResult SymbolTable::replaceAllSymbolUses(StringRef oldSymbol, - StringRef newSymbol, - Operation *from) { - SymbolRefAttr oldAttr = SymbolRefAttr::get(oldSymbol, from->getContext()); - SymbolRefAttr newSymAttr = SymbolRefAttr::get(newSymbol, from->getContext()); +/// Generates a new symbol reference attribute with a new leaf reference. +SymbolRefAttr generateNewRefAttr(SymbolRefAttr oldAttr, + FlatSymbolRefAttr newLeafAttr) { + if (oldAttr.isa<FlatSymbolRefAttr>()) + return newLeafAttr; + auto nestedRefs = llvm::to_vector<2>(oldAttr.getNestedReferences()); + nestedRefs.back() = newLeafAttr; + return SymbolRefAttr::get(oldAttr.getRootReference(), nestedRefs, + oldAttr.getContext()); +} +/// The implementation of SymbolTable::replaceAllSymbolUses below. +template <typename SymbolT> +static LogicalResult replaceAllSymbolUsesImpl(SymbolT symbol, + StringRef newSymbol, + Operation *limit) { // A collection of operations along with their new attribute dictionary. std::vector<std::pair<Operation *, DictionaryAttr>> updatedAttrDicts; - // The current operation, and its old symbol access chains, being processed. + // The current operation being processed. Operation *curOp = nullptr; - SmallVector<SmallVector<int, 1>, 1> accessChains; + + // The set of access chains into the attribute dictionary of the current + // operation, as well as the replacement attribute to use. + SmallVector<std::pair<SmallVector<int, 1>, SymbolRefAttr>, 1> accessChains; // Generate a new attribute dictionary for the current operation by replacing // references to the old symbol. auto generateNewAttrDict = [&] { - auto newAttrDict = - rebuildAttrAfterRAUW(curOp->getAttrList().getDictionary(), accessChains, - newSymAttr, /*depth=*/0); - return newAttrDict.cast<DictionaryAttr>(); + auto oldDict = curOp->getAttrList().getDictionary(); + auto newDict = rebuildAttrAfterRAUW(oldDict, accessChains, /*depth=*/0); + return newDict.cast<DictionaryAttr>(); }; - // Walk the symbol uses collecting uses of the old symbol. - auto walkFn = [&](SymbolTable::SymbolUse symbolUse, - ArrayRef<int> accessChain) { - if (symbolUse.getSymbolRef() != oldAttr) + // Generate a new attribute to replace the given attribute. + MLIRContext *ctx = limit->getContext(); + FlatSymbolRefAttr newLeafAttr = FlatSymbolRefAttr::get(newSymbol, ctx); + auto scopeWalkFn = [&](SymbolRefAttr oldAttr, + Operation *from) -> Optional<WalkResult> { + SymbolRefAttr newAttr = generateNewRefAttr(oldAttr, newLeafAttr); + auto walkFn = [&](SymbolTable::SymbolUse symbolUse, + ArrayRef<int> accessChain) { + SymbolRefAttr useRef = symbolUse.getSymbolRef(); + if (!isReferencePrefixOf(oldAttr, useRef)) + return WalkResult::advance(); + + // If we have a valid match, check to see if this is a proper + // subreference. If it is, then we will need to generate a different new + // attribute specifically for this use. + SymbolRefAttr replacementRef = newAttr; + if (useRef != oldAttr) { + if (oldAttr.isa<FlatSymbolRefAttr>()) { + replacementRef = + SymbolRefAttr::get(newSymbol, useRef.getNestedReferences(), ctx); + } else { + auto nestedRefs = llvm::to_vector<4>(useRef.getNestedReferences()); + nestedRefs[oldAttr.getNestedReferences().size() - 1] = newLeafAttr; + replacementRef = + SymbolRefAttr::get(useRef.getRootReference(), nestedRefs, ctx); + } + } + + // If there was a previous operation, generate a new attribute dict + // for it. This means that we've finished processing the current + // operation, so generate a new dictionary for it. + if (curOp && symbolUse.getUser() != curOp) { + updatedAttrDicts.push_back({curOp, generateNewAttrDict()}); + accessChains.clear(); + } + + // Record this access. + curOp = symbolUse.getUser(); + accessChains.push_back({llvm::to_vector<1>(accessChain), replacementRef}); return WalkResult::advance(); + }; + if (!walkSymbolUses(from, walkFn)) + return llvm::None; - // If there was a previous operation, generate a new attribute dict for it. - // This means that we've finished processing the current operation, so - // generate a new dictionary for it. - if (curOp && symbolUse.getUser() != curOp) { + // Check to see if we have a dangling op that needs to be processed. + if (curOp) { updatedAttrDicts.push_back({curOp, generateNewAttrDict()}); - accessChains.clear(); + curOp = nullptr; } - - // Record this access. - curOp = symbolUse.getUser(); - accessChains.push_back(llvm::to_vector<1>(accessChain)); return WalkResult::advance(); }; - if (!walkSymbolUses(from, walkFn)) + if (!walkSymbolScopes(symbol, limit, scopeWalkFn)) return failure(); // Update the attribute dictionaries as necessary. for (auto &it : updatedAttrDicts) it.first->setAttrs(it.second); - - // Check to see if we have a dangling op that needs to be processed. - if (curOp) - curOp->setAttrs(generateNewAttrDict()); - return success(); } + +/// Attempt to replace all uses of the given symbol 'oldSymbol' with the +/// provided symbol 'newSymbol' that are nested within the given operation +/// 'from'. This does not traverse into any nested symbol tables, and will +/// also only replace uses on 'from' if it does not also define a symbol +/// table. This is because we treat the region as the boundary of the symbol +/// table, and not the op itself. If there are any unknown operations that may +/// potentially be symbol tables, no uses are replaced and failure is returned. +LogicalResult SymbolTable::replaceAllSymbolUses(StringRef oldSymbol, + StringRef newSymbol, + Operation *from) { + return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from); +} +LogicalResult SymbolTable::replaceAllSymbolUses(Operation *oldSymbol, + StringRef newSymbol, + Operation *from) { + return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from); +} |