[Mlir-commits] [mlir] 27b7392 - Add an operator == and != to properties, use it in DuplicateFunctionElimination

Mehdi Amini llvmlistbot at llvm.org
Mon May 15 12:10:52 PDT 2023


Author: Mehdi Amini
Date: 2023-05-15T12:10:46-07:00
New Revision: 27b739228b42ecc7c15664670859e2bbad3b7749

URL: https://github.com/llvm/llvm-project/commit/27b739228b42ecc7c15664670859e2bbad3b7749
DIFF: https://github.com/llvm/llvm-project/commit/27b739228b42ecc7c15664670859e2bbad3b7749.diff

LOG: Add an operator == and != to properties, use it in DuplicateFunctionElimination

Differential Revision: https://reviews.llvm.org/D150596

Added:
    (none)

Modified: 
    mlir/lib/Dialect/Func/Transforms/DuplicateFunctionElimination.cpp
    mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
    mlir/test/lib/Dialect/Test/TestDialect.h
    mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp

Removed:
    (none)


################################################################################
diff  --git a/mlir/lib/Dialect/Func/Transforms/DuplicateFunctionElimination.cpp b/mlir/lib/Dialect/Func/Transforms/DuplicateFunctionElimination.cpp
index b83d67e2ef14a..d41d6c3e8972f 100644
--- a/mlir/lib/Dialect/Func/Transforms/DuplicateFunctionElimination.cpp
+++ b/mlir/lib/Dialect/Func/Transforms/DuplicateFunctionElimination.cpp
@@ -48,33 +48,28 @@ struct DuplicateFuncOpEquivalenceInfo
     return hash;
   }
 
-  static bool isEqual(const func::FuncOp cLhs, const func::FuncOp cRhs) {
-    if (cLhs == cRhs) {
+  static bool isEqual(func::FuncOp lhs, func::FuncOp rhs) {
+    if (lhs == rhs)
       return true;
-    }
-    if (cLhs == getTombstoneKey() || cLhs == getEmptyKey() ||
-        cRhs == getTombstoneKey() || cRhs == getEmptyKey()) {
+    if (lhs == getTombstoneKey() || lhs == getEmptyKey() ||
+        rhs == getTombstoneKey() || rhs == getEmptyKey())
+      return false;
+    // Check discardable attributes equivalence
+    if (lhs->getDiscardableAttrDictionary() !=
+        rhs->getDiscardableAttrDictionary())
       return false;
-    }
 
-    // Check attributes equivalence, ignoring the symbol name.
-    if (cLhs->getAttrDictionary().size() != cRhs->getAttrDictionary().size()) {
+    // Check properties equivalence, ignoring the symbol name.
+    // Make a copy, so that we can erase the symbol name and perform the
+    // comparison.
+    auto pLhs = lhs.getProperties();
+    auto pRhs = rhs.getProperties();
+    pLhs.sym_name = nullptr;
+    pRhs.sym_name = nullptr;
+    if (pLhs != pRhs)
       return false;
-    }
-    func::FuncOp lhs = const_cast<func::FuncOp &>(cLhs);
-    StringAttr symNameAttrName = lhs.getSymNameAttrName();
-    for (NamedAttribute namedAttr : cLhs->getAttrs()) {
-      StringAttr attrName = namedAttr.getName();
-      if (attrName == symNameAttrName) {
-        continue;
-      }
-      if (namedAttr.getValue() != cRhs->getAttr(attrName)) {
-        return false;
-      }
-    }
 
     // Compare inner workings.
-    func::FuncOp rhs = const_cast<func::FuncOp &>(cRhs);
     return OperationEquivalence::isRegionEquivalentTo(
         &lhs.getBody(), &rhs.getBody(), OperationEquivalence::IgnoreLocations);
   }

diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
index 36e967d2d5788..164e2e024423c 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -395,7 +395,7 @@ struct ConvertSelectionOpToSelect
   bool isSameAttrList(spirv::StoreOp lhs, spirv::StoreOp rhs) const {
     return lhs->getDiscardableAttrDictionary() ==
                rhs->getDiscardableAttrDictionary() &&
-           lhs->hashProperties() == rhs->hashProperties();
+           lhs.getProperties() == rhs.getProperties();
   }
 
   // Returns a source value for the given block.

diff  --git a/mlir/test/lib/Dialect/Test/TestDialect.h b/mlir/test/lib/Dialect/Test/TestDialect.h
index dcca76d7dc388..34936783d62ae 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.h
+++ b/mlir/test/lib/Dialect/Test/TestDialect.h
@@ -66,6 +66,9 @@ struct PropertiesWithCustomPrint {
   /// offloaded to the client.
   std::shared_ptr<const std::string> label;
   int value;
+  bool operator==(const PropertiesWithCustomPrint &rhs) const {
+    return value == rhs.value && *label == *rhs.label;
+  }
 };
 class MyPropStruct {
 public:
@@ -77,6 +80,9 @@ class MyPropStruct {
                                          mlir::Attribute attr,
                                          mlir::InFlightDiagnostic *diag);
   llvm::hash_code hash() const;
+  bool operator==(const MyPropStruct &rhs) const {
+    return content == rhs.content;
+  }
 };
 } // namespace test
 

diff  --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index b090de958faff..c287f5254c1f7 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -3433,6 +3433,10 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
     assert(!attrOrProperties.empty());
     std::string declarations = "  struct Properties {\n";
     llvm::raw_string_ostream os(declarations);
+    std::string comparator =
+        "    bool operator==(const Properties &rhs) const {\n"
+        "      return \n";
+    llvm::raw_string_ostream comparatorOs(comparator);
     for (const auto &attrOrProp : attrOrProperties) {
       if (const auto *namedProperty =
               attrOrProp.dyn_cast<const NamedProperty *>()) {
@@ -3447,7 +3451,8 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
            << "    " << name << "Ty " << name;
         if (prop.hasDefaultValue())
           os << " = " << prop.getDefaultValue();
-
+        comparatorOs << "        rhs." << name << " == this->" << name
+                     << " &&\n";
         // Emit accessors using the interface type.
         const char *accessorFmt = R"decl(;
     {0} get{1}() {
@@ -3490,6 +3495,7 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
       }
       os << "    using " << name << "Ty = " << storageType << ";\n"
          << "    " << name << "Ty " << name << ";\n";
+      comparatorOs << "        rhs." << name << " == this->" << name << " &&\n";
 
       // Emit accessors using the interface type.
       if (attr) {
@@ -3509,8 +3515,15 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
                       storageType);
       }
     }
+    comparatorOs << "        true;\n    }\n"
+                    "    bool operator!=(const Properties &rhs) const {\n"
+                    "      return !(*this == rhs);\n"
+                    "    }\n";
+    comparatorOs.flush();
+    os << comparator;
     os << "  };\n";
     os.flush();
+
     genericAdaptorBase.declare<ExtraClassDeclaration>(std::move(declarations));
   }
   genericAdaptorBase.declare<VisibilityDeclaration>(Visibility::Protected);


        


More information about the Mlir-commits mailing list