diff options
author | Mehdi Amini <joker.eph@gmail.com> | 2023-05-15 11:03:24 -0700 |
---|---|---|
committer | Mehdi Amini <joker.eph@gmail.com> | 2023-05-15 12:10:46 -0700 |
commit | 27b739228b42ecc7c15664670859e2bbad3b7749 (patch) | |
tree | a7bcfa155deab061b72caf42fb30a96114656606 /mlir | |
parent | 86c7e33b3fd0cc231b09b5af21ef42842f0ff97b (diff) | |
download | llvm-27b739228b42ecc7c15664670859e2bbad3b7749.tar.gz |
Add an operator == and != to properties, use it in DuplicateFunctionElimination
Differential Revision: https://reviews.llvm.org/D150596
Diffstat (limited to 'mlir')
4 files changed, 37 insertions, 23 deletions
diff --git a/mlir/lib/Dialect/Func/Transforms/DuplicateFunctionElimination.cpp b/mlir/lib/Dialect/Func/Transforms/DuplicateFunctionElimination.cpp index b83d67e2ef14..d41d6c3e8972 100644 --- a/mlir/lib/Dialect/Func/Transforms/DuplicateFunctionElimination.cpp +++ b/mlir/lib/Dialect/Func/Transforms/DuplicateFunctionElimination.cpp @@ -48,33 +48,28 @@ struct DuplicateFuncOpEquivalenceInfo return hash; } - static bool isEqual(const func::FuncOp cLhs, const func::FuncOp cRhs) { - if (cLhs == cRhs) { + static bool isEqual(func::FuncOp lhs, func::FuncOp rhs) { + if (lhs == rhs) return true; - } - if (cLhs == getTombstoneKey() || cLhs == getEmptyKey() || - cRhs == getTombstoneKey() || cRhs == getEmptyKey()) { + if (lhs == getTombstoneKey() || lhs == getEmptyKey() || + rhs == getTombstoneKey() || rhs == getEmptyKey()) + return false; + // Check discardable attributes equivalence + if (lhs->getDiscardableAttrDictionary() != + rhs->getDiscardableAttrDictionary()) return false; - } - // Check attributes equivalence, ignoring the symbol name. - if (cLhs->getAttrDictionary().size() != cRhs->getAttrDictionary().size()) { + // Check properties equivalence, ignoring the symbol name. + // Make a copy, so that we can erase the symbol name and perform the + // comparison. + auto pLhs = lhs.getProperties(); + auto pRhs = rhs.getProperties(); + pLhs.sym_name = nullptr; + pRhs.sym_name = nullptr; + if (pLhs != pRhs) return false; - } - func::FuncOp lhs = const_cast<func::FuncOp &>(cLhs); - StringAttr symNameAttrName = lhs.getSymNameAttrName(); - for (NamedAttribute namedAttr : cLhs->getAttrs()) { - StringAttr attrName = namedAttr.getName(); - if (attrName == symNameAttrName) { - continue; - } - if (namedAttr.getValue() != cRhs->getAttr(attrName)) { - return false; - } - } // Compare inner workings. - func::FuncOp rhs = const_cast<func::FuncOp &>(cRhs); return OperationEquivalence::isRegionEquivalentTo( &lhs.getBody(), &rhs.getBody(), OperationEquivalence::IgnoreLocations); } diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp index 36e967d2d578..164e2e024423 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp @@ -395,7 +395,7 @@ private: bool isSameAttrList(spirv::StoreOp lhs, spirv::StoreOp rhs) const { return lhs->getDiscardableAttrDictionary() == rhs->getDiscardableAttrDictionary() && - lhs->hashProperties() == rhs->hashProperties(); + lhs.getProperties() == rhs.getProperties(); } // Returns a source value for the given block. diff --git a/mlir/test/lib/Dialect/Test/TestDialect.h b/mlir/test/lib/Dialect/Test/TestDialect.h index dcca76d7dc38..34936783d62a 100644 --- a/mlir/test/lib/Dialect/Test/TestDialect.h +++ b/mlir/test/lib/Dialect/Test/TestDialect.h @@ -66,6 +66,9 @@ struct PropertiesWithCustomPrint { /// offloaded to the client. std::shared_ptr<const std::string> label; int value; + bool operator==(const PropertiesWithCustomPrint &rhs) const { + return value == rhs.value && *label == *rhs.label; + } }; class MyPropStruct { public: @@ -77,6 +80,9 @@ public: mlir::Attribute attr, mlir::InFlightDiagnostic *diag); llvm::hash_code hash() const; + bool operator==(const MyPropStruct &rhs) const { + return content == rhs.content; + } }; } // namespace test diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp index b090de958faf..c287f5254c1f 100644 --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -3433,6 +3433,10 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter( assert(!attrOrProperties.empty()); std::string declarations = " struct Properties {\n"; llvm::raw_string_ostream os(declarations); + std::string comparator = + " bool operator==(const Properties &rhs) const {\n" + " return \n"; + llvm::raw_string_ostream comparatorOs(comparator); for (const auto &attrOrProp : attrOrProperties) { if (const auto *namedProperty = attrOrProp.dyn_cast<const NamedProperty *>()) { @@ -3447,7 +3451,8 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter( << " " << name << "Ty " << name; if (prop.hasDefaultValue()) os << " = " << prop.getDefaultValue(); - + comparatorOs << " rhs." << name << " == this->" << name + << " &&\n"; // Emit accessors using the interface type. const char *accessorFmt = R"decl(; {0} get{1}() { @@ -3490,6 +3495,7 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter( } os << " using " << name << "Ty = " << storageType << ";\n" << " " << name << "Ty " << name << ";\n"; + comparatorOs << " rhs." << name << " == this->" << name << " &&\n"; // Emit accessors using the interface type. if (attr) { @@ -3509,8 +3515,15 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter( storageType); } } + comparatorOs << " true;\n }\n" + " bool operator!=(const Properties &rhs) const {\n" + " return !(*this == rhs);\n" + " }\n"; + comparatorOs.flush(); + os << comparator; os << " };\n"; os.flush(); + genericAdaptorBase.declare<ExtraClassDeclaration>(std::move(declarations)); } genericAdaptorBase.declare<VisibilityDeclaration>(Visibility::Protected); |