[Mlir-commits] [mlir] [mlir] add a way to query non-property attributes (PR #76959)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Jan 4 06:06:01 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-core

Author: Oleksandr "Alex" Zinenko (ftynse)

<details>
<summary>Changes</summary>

This change adds `Operation::getNonPropertyAttrDictionary()`, which helps support generic manipulation of operations that don't (yet) use properties to store inherent attributes.

Use this mechanism in type inference and operation equivalence.

Note that only minimal unit tests are introduced, as all the upstream dialects seem to have been updated to use properties, and the non-property behavior is essentially deprecated and untested.

---
Full diff: https://github.com/llvm/llvm-project/pull/76959.diff


5 Files Affected:

- (modified) mlir/include/mlir/IR/Operation.h (+6) 
- (modified) mlir/lib/IR/OperationSupport.cpp (+5-5) 
- (modified) mlir/lib/Interfaces/InferTypeOpInterface.cpp (+2-3) 
- (modified) mlir/unittests/Bytecode/BytecodeTest.cpp (+3) 
- (modified) mlir/unittests/IR/OpPropertiesTest.cpp (+10) 


``````````diff
diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h
index d2f52cf1afee81..546ac0b118b552 100644
--- a/mlir/include/mlir/IR/Operation.h
+++ b/mlir/include/mlir/IR/Operation.h
@@ -500,6 +500,12 @@ class alignas(8) Operation final
                                llvm::to_vector(getDiscardableAttrs()));
   }
 
+  /// Return all attributes that are not stored as properties.
+  DictionaryAttr getNonPropertyAttrDictionary() {
+    return getPropertiesStorage() ? getDiscardableAttrDictionary()
+                                  : getAttrDictionary();
+  }
+
   /// Return all of the attributes on this operation.
   ArrayRef<NamedAttribute> getAttrs() { return getAttrDictionary().getValue(); }
 
diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp
index eaad6f2891608f..2bd7ac35024410 100644
--- a/mlir/lib/IR/OperationSupport.cpp
+++ b/mlir/lib/IR/OperationSupport.cpp
@@ -675,7 +675,7 @@ llvm::hash_code OperationEquivalence::computeHash(
   //   - Attributes
   //   - Result Types
   llvm::hash_code hash =
-      llvm::hash_combine(op->getName(), op->getDiscardableAttrDictionary(),
+      llvm::hash_combine(op->getName(), op->getNonPropertyAttrDictionary(),
                          op->getResultTypes(), op->hashProperties());
 
   //   - Location if required
@@ -831,14 +831,14 @@ OperationEquivalence::isRegionEquivalentTo(Region *lhs, Region *rhs,
 
   // 1. Compare the operation properties.
   if (lhs->getName() != rhs->getName() ||
-      lhs->getDiscardableAttrDictionary() !=
-          rhs->getDiscardableAttrDictionary() ||
+      lhs->getNonPropertyAttrDictionary() !=
+          rhs->getNonPropertyAttrDictionary() ||
       lhs->getNumRegions() != rhs->getNumRegions() ||
       lhs->getNumSuccessors() != rhs->getNumSuccessors() ||
       lhs->getNumOperands() != rhs->getNumOperands() ||
       lhs->getNumResults() != rhs->getNumResults() ||
       !lhs->getName().compareOpProperties(lhs->getPropertiesStorage(),
-                                        rhs->getPropertiesStorage()))
+                                          rhs->getPropertiesStorage()))
     return false;
   if (!(flags & IgnoreLocations) && lhs->getLoc() != rhs->getLoc())
     return false;
@@ -923,7 +923,7 @@ OperationFingerPrint::OperationFingerPrint(Operation *topOp) {
     if (op != topOp)
       addDataToHash(hasher, op->getParentOp());
     //   - Attributes
-    addDataToHash(hasher, op->getDiscardableAttrDictionary());
+    addDataToHash(hasher, op->getNonPropertyAttrDictionary());
     //   - Properties
     addDataToHash(hasher, op->hashProperties());
     //   - Blocks in Regions
diff --git a/mlir/lib/Interfaces/InferTypeOpInterface.cpp b/mlir/lib/Interfaces/InferTypeOpInterface.cpp
index ee4c0519b9f54f..a24375164f8498 100644
--- a/mlir/lib/Interfaces/InferTypeOpInterface.cpp
+++ b/mlir/lib/Interfaces/InferTypeOpInterface.cpp
@@ -240,9 +240,8 @@ LogicalResult mlir::detail::verifyInferredResultTypes(Operation *op) {
   auto retTypeFn = cast<InferTypeOpInterface>(op);
   auto result = retTypeFn.refineReturnTypes(
       op->getContext(), op->getLoc(), op->getOperands(),
-      op->getPropertiesStorage() ? op->getDiscardableAttrDictionary()
-                                 : op->getAttrDictionary(),
-      op->getPropertiesStorage(), op->getRegions(), inferredReturnTypes);
+      op->getNonPropertyAttrDictionary(), op->getPropertiesStorage(),
+      op->getRegions(), inferredReturnTypes);
   if (failed(result))
     op->emitOpError() << "failed to infer returned types";
 
diff --git a/mlir/unittests/Bytecode/BytecodeTest.cpp b/mlir/unittests/Bytecode/BytecodeTest.cpp
index 76ff1a8194db74..bb7241c2d51969 100644
--- a/mlir/unittests/Bytecode/BytecodeTest.cpp
+++ b/mlir/unittests/Bytecode/BytecodeTest.cpp
@@ -144,4 +144,7 @@ TEST(Bytecode, OpWithoutProperties) {
   EXPECT_EQ(roundtripped->getAttrs().size(), 2u);
   EXPECT_TRUE(roundtripped->getInherentAttr("inherent_attr") != std::nullopt);
   EXPECT_TRUE(roundtripped->getDiscardableAttr("other_attr") != Attribute());
+
+  EXPECT_TRUE(OperationEquivalence::computeHash(op.get()) ==
+              OperationEquivalence::computeHash(roundtripped));
 }
diff --git a/mlir/unittests/IR/OpPropertiesTest.cpp b/mlir/unittests/IR/OpPropertiesTest.cpp
index bb1b741d1cc223..365775d541ec3d 100644
--- a/mlir/unittests/IR/OpPropertiesTest.cpp
+++ b/mlir/unittests/IR/OpPropertiesTest.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/OperationSupport.h"
 #include "mlir/Parser/Parser.h"
 #include "gtest/gtest.h"
 #include <optional>
@@ -401,6 +402,15 @@ TEST(OpPropertiesTest, withoutPropertiesDiscardableAttrs) {
   op->print(os);
   EXPECT_TRUE(StringRef(os.str()).contains("inherent_attr = 42"));
   EXPECT_TRUE(StringRef(os.str()).contains("other_attr = 56"));
+
+  OwningOpRef<Operation *> reparsed = parseSourceString(os.str(), config);
+  auto trivialHash = [](Value v) { return hash_value(v); };
+  auto hash = [&](Operation *operation) {
+    return OperationEquivalence::computeHash(
+        operation, trivialHash, trivialHash,
+        OperationEquivalence::Flags::IgnoreLocations);
+  };
+  EXPECT_TRUE(hash(op.get()) == hash(reparsed.get()));
 }
 
 } // namespace

``````````

</details>


https://github.com/llvm/llvm-project/pull/76959


More information about the Mlir-commits mailing list