[Mlir-commits] [mlir] b336ab4 - [mlir] add a way to query non-property attributes (#76959)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Jan 4 07:40:16 PST 2024


Author: Oleksandr "Alex" Zinenko
Date: 2024-01-04T16:40:13+01:00
New Revision: b336ab42dcc81a351b2f875f28c70b74d8814611

URL: https://github.com/llvm/llvm-project/commit/b336ab42dcc81a351b2f875f28c70b74d8814611
DIFF: https://github.com/llvm/llvm-project/commit/b336ab42dcc81a351b2f875f28c70b74d8814611.diff

LOG: [mlir] add a way to query non-property attributes (#76959)

This helps support generic manipulation of operations that don't (yet)
use properties to store inherent attributes.

Use this mechanism in type inference and operation equivalence.

Note that only minimal unit tests are introduced, as all the upstream
dialects seem to have been updated to use properties, and the
non-property behavior is essentially deprecated and untested.

Added: 
    

Modified: 
    mlir/include/mlir/IR/Operation.h
    mlir/lib/IR/OperationSupport.cpp
    mlir/lib/Interfaces/InferTypeOpInterface.cpp
    mlir/unittests/Bytecode/BytecodeTest.cpp
    mlir/unittests/IR/OpPropertiesTest.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h
index d2f52cf1afee81..3ffd3517fe5a66 100644
--- a/mlir/include/mlir/IR/Operation.h
+++ b/mlir/include/mlir/IR/Operation.h
@@ -500,6 +500,9 @@ class alignas(8) Operation final
                                llvm::to_vector(getDiscardableAttrs()));
   }
 
+  /// Return all attributes that are not stored as properties.
+  DictionaryAttr getRawDictionaryAttrs() { return attrs; }
+
   /// Return all of the attributes on this operation.
   ArrayRef<NamedAttribute> getAttrs() { return getAttrDictionary().getValue(); }
 

diff  --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp
index eaad6f2891608f..e10cd748e03ba5 100644
--- a/mlir/lib/IR/OperationSupport.cpp
+++ b/mlir/lib/IR/OperationSupport.cpp
@@ -675,7 +675,7 @@ llvm::hash_code OperationEquivalence::computeHash(
   //   - Attributes
   //   - Result Types
   llvm::hash_code hash =
-      llvm::hash_combine(op->getName(), op->getDiscardableAttrDictionary(),
+      llvm::hash_combine(op->getName(), op->getRawDictionaryAttrs(),
                          op->getResultTypes(), op->hashProperties());
 
   //   - Location if required
@@ -831,14 +831,13 @@ OperationEquivalence::isRegionEquivalentTo(Region *lhs, Region *rhs,
 
   // 1. Compare the operation properties.
   if (lhs->getName() != rhs->getName() ||
-      lhs->getDiscardableAttrDictionary() !=
-          rhs->getDiscardableAttrDictionary() ||
+      lhs->getRawDictionaryAttrs() != rhs->getRawDictionaryAttrs() ||
       lhs->getNumRegions() != rhs->getNumRegions() ||
       lhs->getNumSuccessors() != rhs->getNumSuccessors() ||
       lhs->getNumOperands() != rhs->getNumOperands() ||
       lhs->getNumResults() != rhs->getNumResults() ||
       !lhs->getName().compareOpProperties(lhs->getPropertiesStorage(),
-                                        rhs->getPropertiesStorage()))
+                                          rhs->getPropertiesStorage()))
     return false;
   if (!(flags & IgnoreLocations) && lhs->getLoc() != rhs->getLoc())
     return false;
@@ -923,7 +922,7 @@ OperationFingerPrint::OperationFingerPrint(Operation *topOp) {
     if (op != topOp)
       addDataToHash(hasher, op->getParentOp());
     //   - Attributes
-    addDataToHash(hasher, op->getDiscardableAttrDictionary());
+    addDataToHash(hasher, op->getRawDictionaryAttrs());
     //   - Properties
     addDataToHash(hasher, op->hashProperties());
     //   - Blocks in Regions

diff  --git a/mlir/lib/Interfaces/InferTypeOpInterface.cpp b/mlir/lib/Interfaces/InferTypeOpInterface.cpp
index ee4c0519b9f54f..e52d0e17cda22b 100644
--- a/mlir/lib/Interfaces/InferTypeOpInterface.cpp
+++ b/mlir/lib/Interfaces/InferTypeOpInterface.cpp
@@ -240,9 +240,8 @@ LogicalResult mlir::detail::verifyInferredResultTypes(Operation *op) {
   auto retTypeFn = cast<InferTypeOpInterface>(op);
   auto result = retTypeFn.refineReturnTypes(
       op->getContext(), op->getLoc(), op->getOperands(),
-      op->getPropertiesStorage() ? op->getDiscardableAttrDictionary()
-                                 : op->getAttrDictionary(),
-      op->getPropertiesStorage(), op->getRegions(), inferredReturnTypes);
+      op->getRawDictionaryAttrs(), op->getPropertiesStorage(), op->getRegions(),
+      inferredReturnTypes);
   if (failed(result))
     op->emitOpError() << "failed to infer returned types";
 

diff  --git a/mlir/unittests/Bytecode/BytecodeTest.cpp b/mlir/unittests/Bytecode/BytecodeTest.cpp
index 76ff1a8194db74..bb7241c2d51969 100644
--- a/mlir/unittests/Bytecode/BytecodeTest.cpp
+++ b/mlir/unittests/Bytecode/BytecodeTest.cpp
@@ -144,4 +144,7 @@ TEST(Bytecode, OpWithoutProperties) {
   EXPECT_EQ(roundtripped->getAttrs().size(), 2u);
   EXPECT_TRUE(roundtripped->getInherentAttr("inherent_attr") != std::nullopt);
   EXPECT_TRUE(roundtripped->getDiscardableAttr("other_attr") != Attribute());
+
+  EXPECT_TRUE(OperationEquivalence::computeHash(op.get()) ==
+              OperationEquivalence::computeHash(roundtripped));
 }

diff  --git a/mlir/unittests/IR/OpPropertiesTest.cpp b/mlir/unittests/IR/OpPropertiesTest.cpp
index bb1b741d1cc223..365775d541ec3d 100644
--- a/mlir/unittests/IR/OpPropertiesTest.cpp
+++ b/mlir/unittests/IR/OpPropertiesTest.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/OperationSupport.h"
 #include "mlir/Parser/Parser.h"
 #include "gtest/gtest.h"
 #include <optional>
@@ -401,6 +402,15 @@ TEST(OpPropertiesTest, withoutPropertiesDiscardableAttrs) {
   op->print(os);
   EXPECT_TRUE(StringRef(os.str()).contains("inherent_attr = 42"));
   EXPECT_TRUE(StringRef(os.str()).contains("other_attr = 56"));
+
+  OwningOpRef<Operation *> reparsed = parseSourceString(os.str(), config);
+  auto trivialHash = [](Value v) { return hash_value(v); };
+  auto hash = [&](Operation *operation) {
+    return OperationEquivalence::computeHash(
+        operation, trivialHash, trivialHash,
+        OperationEquivalence::Flags::IgnoreLocations);
+  };
+  EXPECT_TRUE(hash(op.get()) == hash(reparsed.get()));
 }
 
 } // namespace


        


More information about the Mlir-commits mailing list