summaryrefslogtreecommitdiff
path: root/mlir
diff options
context:
space:
mode:
authorMehdi Amini <joker.eph@gmail.com>2023-05-15 11:03:24 -0700
committerMehdi Amini <joker.eph@gmail.com>2023-05-15 12:10:46 -0700
commit27b739228b42ecc7c15664670859e2bbad3b7749 (patch)
treea7bcfa155deab061b72caf42fb30a96114656606 /mlir
parent86c7e33b3fd0cc231b09b5af21ef42842f0ff97b (diff)
downloadllvm-27b739228b42ecc7c15664670859e2bbad3b7749.tar.gz
Add an operator == and != to properties, use it in DuplicateFunctionElimination
Differential Revision: https://reviews.llvm.org/D150596
Diffstat (limited to 'mlir')
-rw-r--r--mlir/lib/Dialect/Func/Transforms/DuplicateFunctionElimination.cpp37
-rw-r--r--mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp2
-rw-r--r--mlir/test/lib/Dialect/Test/TestDialect.h6
-rw-r--r--mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp15
4 files changed, 37 insertions, 23 deletions
diff --git a/mlir/lib/Dialect/Func/Transforms/DuplicateFunctionElimination.cpp b/mlir/lib/Dialect/Func/Transforms/DuplicateFunctionElimination.cpp
index b83d67e2ef14..d41d6c3e8972 100644
--- a/mlir/lib/Dialect/Func/Transforms/DuplicateFunctionElimination.cpp
+++ b/mlir/lib/Dialect/Func/Transforms/DuplicateFunctionElimination.cpp
@@ -48,33 +48,28 @@ struct DuplicateFuncOpEquivalenceInfo
return hash;
}
- static bool isEqual(const func::FuncOp cLhs, const func::FuncOp cRhs) {
- if (cLhs == cRhs) {
+ static bool isEqual(func::FuncOp lhs, func::FuncOp rhs) {
+ if (lhs == rhs)
return true;
- }
- if (cLhs == getTombstoneKey() || cLhs == getEmptyKey() ||
- cRhs == getTombstoneKey() || cRhs == getEmptyKey()) {
+ if (lhs == getTombstoneKey() || lhs == getEmptyKey() ||
+ rhs == getTombstoneKey() || rhs == getEmptyKey())
+ return false;
+ // Check discardable attributes equivalence
+ if (lhs->getDiscardableAttrDictionary() !=
+ rhs->getDiscardableAttrDictionary())
return false;
- }
- // Check attributes equivalence, ignoring the symbol name.
- if (cLhs->getAttrDictionary().size() != cRhs->getAttrDictionary().size()) {
+ // Check properties equivalence, ignoring the symbol name.
+ // Make a copy, so that we can erase the symbol name and perform the
+ // comparison.
+ auto pLhs = lhs.getProperties();
+ auto pRhs = rhs.getProperties();
+ pLhs.sym_name = nullptr;
+ pRhs.sym_name = nullptr;
+ if (pLhs != pRhs)
return false;
- }
- func::FuncOp lhs = const_cast<func::FuncOp &>(cLhs);
- StringAttr symNameAttrName = lhs.getSymNameAttrName();
- for (NamedAttribute namedAttr : cLhs->getAttrs()) {
- StringAttr attrName = namedAttr.getName();
- if (attrName == symNameAttrName) {
- continue;
- }
- if (namedAttr.getValue() != cRhs->getAttr(attrName)) {
- return false;
- }
- }
// Compare inner workings.
- func::FuncOp rhs = const_cast<func::FuncOp &>(cRhs);
return OperationEquivalence::isRegionEquivalentTo(
&lhs.getBody(), &rhs.getBody(), OperationEquivalence::IgnoreLocations);
}
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
index 36e967d2d578..164e2e024423 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -395,7 +395,7 @@ private:
bool isSameAttrList(spirv::StoreOp lhs, spirv::StoreOp rhs) const {
return lhs->getDiscardableAttrDictionary() ==
rhs->getDiscardableAttrDictionary() &&
- lhs->hashProperties() == rhs->hashProperties();
+ lhs.getProperties() == rhs.getProperties();
}
// Returns a source value for the given block.
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.h b/mlir/test/lib/Dialect/Test/TestDialect.h
index dcca76d7dc38..34936783d62a 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.h
+++ b/mlir/test/lib/Dialect/Test/TestDialect.h
@@ -66,6 +66,9 @@ struct PropertiesWithCustomPrint {
/// offloaded to the client.
std::shared_ptr<const std::string> label;
int value;
+ bool operator==(const PropertiesWithCustomPrint &rhs) const {
+ return value == rhs.value && *label == *rhs.label;
+ }
};
class MyPropStruct {
public:
@@ -77,6 +80,9 @@ public:
mlir::Attribute attr,
mlir::InFlightDiagnostic *diag);
llvm::hash_code hash() const;
+ bool operator==(const MyPropStruct &rhs) const {
+ return content == rhs.content;
+ }
};
} // namespace test
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index b090de958faf..c287f5254c1f 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -3433,6 +3433,10 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
assert(!attrOrProperties.empty());
std::string declarations = " struct Properties {\n";
llvm::raw_string_ostream os(declarations);
+ std::string comparator =
+ " bool operator==(const Properties &rhs) const {\n"
+ " return \n";
+ llvm::raw_string_ostream comparatorOs(comparator);
for (const auto &attrOrProp : attrOrProperties) {
if (const auto *namedProperty =
attrOrProp.dyn_cast<const NamedProperty *>()) {
@@ -3447,7 +3451,8 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
<< " " << name << "Ty " << name;
if (prop.hasDefaultValue())
os << " = " << prop.getDefaultValue();
-
+ comparatorOs << " rhs." << name << " == this->" << name
+ << " &&\n";
// Emit accessors using the interface type.
const char *accessorFmt = R"decl(;
{0} get{1}() {
@@ -3490,6 +3495,7 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
}
os << " using " << name << "Ty = " << storageType << ";\n"
<< " " << name << "Ty " << name << ";\n";
+ comparatorOs << " rhs." << name << " == this->" << name << " &&\n";
// Emit accessors using the interface type.
if (attr) {
@@ -3509,8 +3515,15 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
storageType);
}
}
+ comparatorOs << " true;\n }\n"
+ " bool operator!=(const Properties &rhs) const {\n"
+ " return !(*this == rhs);\n"
+ " }\n";
+ comparatorOs.flush();
+ os << comparator;
os << " };\n";
os.flush();
+
genericAdaptorBase.declare<ExtraClassDeclaration>(std::move(declarations));
}
genericAdaptorBase.declare<VisibilityDeclaration>(Visibility::Protected);