[Mlir-commits] [mlir] [mlir] Add option to ignore commutativity in OperationEquality (PR #181507)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Feb 15 12:33:47 PST 2026


https://github.com/jumerckx updated https://github.com/llvm/llvm-project/pull/181507

>From 1b39f5788bac7fa74aff16c90ce6dd32c27249a4 Mon Sep 17 00:00:00 2001
From: jumerckx <31353884+jumerckx at users.noreply.github.com>
Date: Sat, 14 Feb 2026 17:22:12 +0100
Subject: [PATCH 1/8] fix docstring

---
 mlir/include/mlir/IR/OperationSupport.h | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
index 1ff7c56ddca38..a05e0ae387164 100644
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -1348,8 +1348,8 @@ struct OperationEquivalence {
   /// Helper that can be used with `computeHash` above to ignore operation
   /// operands/result mapping.
   static llvm::hash_code ignoreHashValue(Value) { return llvm::hash_code{}; }
-  /// Helper that can be used with `computeHash` above to ignore operation
-  /// operands/result mapping.
+  /// Helper that can be used with `computeHash` to compute the hash value
+  /// of operands/results directly.
   static llvm::hash_code directHashValue(Value v) { return hash_value(v); }
 
   /// Compare two operations (including their regions) and return if they are

>From 179b5806470f4d301f410de5937f344f5f5cec9f Mon Sep 17 00:00:00 2001
From: jumerckx <31353884+jumerckx at users.noreply.github.com>
Date: Sat, 14 Feb 2026 17:22:41 +0100
Subject: [PATCH 2/8] IgnoreCommutativity flag

---
 mlir/include/mlir/IR/OperationSupport.h | 6 +++++-
 mlir/lib/IR/OperationSupport.cpp        | 5 +++--
 2 files changed, 8 insertions(+), 3 deletions(-)

diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
index a05e0ae387164..3b79f263568dd 100644
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -1331,7 +1331,11 @@ struct OperationEquivalence {
     // When provided, the properties attached to the operation are ignored.
     IgnoreProperties = 4,
 
-    LLVM_MARK_AS_BITMASK_ENUM(/* LargestValue = */ IgnoreProperties)
+    // When provided, the commutativity of the operation is ignored, and
+    // operands are compared in an order-sensitive way.
+    IgnoreCommutativity = 8,
+
+    LLVM_MARK_AS_BITMASK_ENUM(/* LargestValue = */ IgnoreCommutativity)
   };
 
   /// Compute a hash for the given operation.
diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp
index 2a37f3860fe00..3ff61daaac60b 100644
--- a/mlir/lib/IR/OperationSupport.cpp
+++ b/mlir/lib/IR/OperationSupport.cpp
@@ -689,7 +689,8 @@ llvm::hash_code OperationEquivalence::computeHash(
     hash = llvm::hash_combine(hash, op->getLoc());
 
   //   - Operands
-  if (op->hasTrait<mlir::OpTrait::IsCommutative>() &&
+  if (!(flags & Flags::IgnoreCommutativity) &&
+      op->hasTrait<mlir::OpTrait::IsCommutative>() &&
       op->getNumOperands() > 0) {
     size_t operandHash = hashOperands(op->getOperand(0));
     for (auto operand : op->getOperands().drop_front())
@@ -854,7 +855,7 @@ OperationEquivalence::isRegionEquivalentTo(Region *lhs, Region *rhs,
     return false;
 
   // 2. Compare operands.
-  if (checkCommutativeEquivalent &&
+  if (!(flags & IgnoreCommutativity) && checkCommutativeEquivalent &&
       lhs->hasTrait<mlir::OpTrait::IsCommutative>()) {
     auto lhsRange = lhs->getOperands();
     auto rhsRange = rhs->getOperands();

>From a4702e0312b5f39764db1f09bfdea1386d1cdf5f Mon Sep 17 00:00:00 2001
From: jumerckx <31353884+jumerckx at users.noreply.github.com>
Date: Sat, 14 Feb 2026 17:22:51 +0100
Subject: [PATCH 3/8] test IgnoreCommutativity

---
 mlir/test/IR/operation-equality.mlir     | 28 ++++++++++++++++++++++++
 mlir/test/lib/IR/TestOperationEquals.cpp |  2 ++
 2 files changed, 30 insertions(+)

diff --git a/mlir/test/IR/operation-equality.mlir b/mlir/test/IR/operation-equality.mlir
index f382d7d0fbf1b..3138eb2ee41dc 100644
--- a/mlir/test/IR/operation-equality.mlir
+++ b/mlir/test/IR/operation-equality.mlir
@@ -184,3 +184,31 @@
   %0:2 = "test.producer"() : () -> (i32, i32)
   "test.consumer"(%0#1, %0#0) : (i32, i32) -> ()
   }) : () -> ()
+
+// -----
+
+// CHECK-LABEL: test.commutatively_equal
+// CHECK-SAME: compares equals
+
+"test.commutatively_equal"() ({
+  ^bb0(%arg0 : i32, %arg1 : i32):
+    arith.addi %arg0, %arg1 : i32
+  }) : () -> ()
+"test.commutatively_equal"() ({
+  ^bb0(%arg0 : i32, %arg1 : i32):
+    arith.addi %arg1, %arg0 : i32
+  }) : () -> ()
+
+// -----
+
+// CHECK-LABEL: test.ignore_commutatively_equal
+// CHECK-SAME: compares NOT equals
+
+"test.ignore_commutatively_equal"() ({
+  ^bb0(%arg0 : i32, %arg1 : i32):
+    arith.addi %arg0, %arg1 : i32
+  }) { ignore_commutativity } : () -> ()
+"test.ignore_commutatively_equal"() ({
+  ^bb0(%arg0 : i32, %arg1 : i32):
+    arith.addi %arg1, %arg0 : i32
+  }) { ignore_commutativity } : () -> ()
diff --git a/mlir/test/lib/IR/TestOperationEquals.cpp b/mlir/test/lib/IR/TestOperationEquals.cpp
index 03cf5f4facf82..a8ff7752ebb2c 100644
--- a/mlir/test/lib/IR/TestOperationEquals.cpp
+++ b/mlir/test/lib/IR/TestOperationEquals.cpp
@@ -35,6 +35,8 @@ struct TestOperationEqualPass
     OperationEquivalence::Flags flags{};
     if (!first->hasAttr("strict_loc_check"))
       flags |= OperationEquivalence::IgnoreLocations;
+    if (first->hasAttr("ignore_commutativity"))
+      flags |= OperationEquivalence::IgnoreCommutativity;
     if (OperationEquivalence::isEquivalentTo(first, &module.getBody()->back(),
                                              flags))
       llvm::outs() << " compares equals.\n";

>From 3b272e3dbffa0497c710928e053491c7f711b3e4 Mon Sep 17 00:00:00 2001
From: jumerckx <julesmerckx12 at gmail.com>
Date: Sat, 14 Feb 2026 20:07:11 +0100
Subject: [PATCH 4/8] propagate checkCommutativeEquivalent

---
 mlir/lib/IR/OperationSupport.cpp | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp
index 3ff61daaac60b..b9641a294fc0c 100644
--- a/mlir/lib/IR/OperationSupport.cpp
+++ b/mlir/lib/IR/OperationSupport.cpp
@@ -887,9 +887,9 @@ OperationEquivalence::isRegionEquivalentTo(Region *lhs, Region *rhs,
 
   // 4. Compare regions.
   for (auto regionPair : llvm::zip(lhs->getRegions(), rhs->getRegions()))
-    if (!isRegionEquivalentTo(&std::get<0>(regionPair),
-                              &std::get<1>(regionPair), checkEquivalent,
-                              markEquivalent, flags))
+    if (!isRegionEquivalentTo(
+            &std::get<0>(regionPair), &std::get<1>(regionPair), checkEquivalent,
+            markEquivalent, flags, checkCommutativeEquivalent))
       return false;
 
   return true;

>From 9a22aea8a56c0ab6ea04ec23b8b14edb1cc96199 Mon Sep 17 00:00:00 2001
From: jumerckx <julesmerckx12 at gmail.com>
Date: Sat, 14 Feb 2026 20:07:24 +0100
Subject: [PATCH 5/8] fully check commutative equivalence

---
 mlir/lib/IR/OperationSupport.cpp | 11 ++++++++++-
 1 file changed, 10 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp
index b9641a294fc0c..0414bb99222b6 100644
--- a/mlir/lib/IR/OperationSupport.cpp
+++ b/mlir/lib/IR/OperationSupport.cpp
@@ -797,7 +797,16 @@ struct ValueEquivalenceCache {
     };
     auto lhsSorted = sortValues({lhsIt, lhsRange.end()});
     auto rhsSorted = sortValues({rhsIt, rhsRange.end()});
-    return success(lhsSorted == rhsSorted);
+    if (lhsSorted == rhsSorted) {
+      return success();
+    }
+    for (auto operandPair : llvm::zip(lhsSorted, rhsSorted)) {
+      Value lhs = std::get<0>(operandPair);
+      Value rhs = std::get<1>(operandPair);
+      if (failed(checkEquivalent(lhs, rhs)))
+        return failure();
+    }
+    return success();
   }
   void markEquivalent(Value lhsResult, Value rhsResult) {
     auto insertion = equivalentValues.insert({lhsResult, rhsResult});

>From 47e2a1ab9db2d8b39ef746ab750ffaae0e5dd555 Mon Sep 17 00:00:00 2001
From: jumerckx <julesmerckx12 at gmail.com>
Date: Sat, 14 Feb 2026 20:33:25 +0100
Subject: [PATCH 6/8] update test

---
 mlir/test/Dialect/Func/duplicate-function-elimination.mlir | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/mlir/test/Dialect/Func/duplicate-function-elimination.mlir b/mlir/test/Dialect/Func/duplicate-function-elimination.mlir
index bc04e8fa9cd23..4d00d8a954d17 100644
--- a/mlir/test/Dialect/Func/duplicate-function-elimination.mlir
+++ b/mlir/test/Dialect/Func/duplicate-function-elimination.mlir
@@ -58,11 +58,10 @@ func.func @user(%arg0: f32, %arg1: f32) -> f32 {
 
 // CHECK:     @add_lr
 // CHECK-NOT: @also_add_lr
-// CHECK:     @add_rl
+// CHECK-NOT: @add_rl
 // CHECK-NOT: @also_add_rl
 // CHECK:     @user
-// CHECK-2:     call @add_lr
-// CHECK-2:     call @add_rl
+// CHECK-COUNT-4: call @add_lr
 
 // -----
 

>From 1bb1a0a5d265bf436c9189db370905d832126c5e Mon Sep 17 00:00:00 2001
From: jumerckx <julesmerckx12 at gmail.com>
Date: Sat, 14 Feb 2026 21:43:28 +0100
Subject: [PATCH 7/8] check full commutativity by using equivalentValues as a
 union-find

---
 mlir/lib/IR/OperationSupport.cpp | 19 +++++++------------
 1 file changed, 7 insertions(+), 12 deletions(-)

diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp
index 0414bb99222b6..309b7c730f018 100644
--- a/mlir/lib/IR/OperationSupport.cpp
+++ b/mlir/lib/IR/OperationSupport.cpp
@@ -785,11 +785,12 @@ struct ValueEquivalenceCache {
     if (lhsIt == lhsRange.end())
       return success();
 
-    // Handle another simple case where operands are just a permutation.
-    // Note: This is not sufficient, this handles simple cases relatively
-    // cheaply.
-    auto sortValues = [](ValueRange values) {
-      SmallVector<Value> sortedValues = llvm::to_vector(values);
+    // Map each value to its recorded counterpart in equivalentValues (values
+    // without an entry map to themselves), so that comparing the sorted lists
+    // is enough to detect operands that are a permutation of each other.
+    auto sortValues = [this](ValueRange values) {
+      SmallVector<Value> sortedValues = llvm::map_to_vector(
+          values, [this](Value a) { return equivalentValues.lookup_or(a, a); });
       llvm::sort(sortedValues, [](Value a, Value b) {
         return a.getAsOpaquePointer() < b.getAsOpaquePointer();
       });
@@ -800,13 +801,7 @@ struct ValueEquivalenceCache {
     if (lhsSorted == rhsSorted) {
       return success();
     }
-    for (auto operandPair : llvm::zip(lhsSorted, rhsSorted)) {
-      Value lhs = std::get<0>(operandPair);
-      Value rhs = std::get<1>(operandPair);
-      if (failed(checkEquivalent(lhs, rhs)))
-        return failure();
-    }
-    return success();
+    return failure();
   }
   void markEquivalent(Value lhsResult, Value rhsResult) {
     auto insertion = equivalentValues.insert({lhsResult, rhsResult});

>From 9c1e9d8727d7e50ad8bcd15ec31a0cb30a5a734c Mon Sep 17 00:00:00 2001
From: jumerckx <julesmerckx12 at gmail.com>
Date: Sun, 15 Feb 2026 21:31:54 +0100
Subject: [PATCH 8/8] Add test for ignore_commutativity that works without
 further changes to equality

---
 mlir/test/IR/operation-equality.mlir     | 30 ++++++++++++++++++++++++
 mlir/test/lib/IR/TestOperationEquals.cpp | 17 ++++++++++----
 2 files changed, 42 insertions(+), 5 deletions(-)

diff --git a/mlir/test/IR/operation-equality.mlir b/mlir/test/IR/operation-equality.mlir
index 3138eb2ee41dc..ed51b610fbc46 100644
--- a/mlir/test/IR/operation-equality.mlir
+++ b/mlir/test/IR/operation-equality.mlir
@@ -187,6 +187,36 @@
 
 // -----
 
+// CHECK-LABEL: test.commutatively_equal_permutation
+// CHECK-SAME: compares equals
+
+builtin.module attributes {test.includes_setup} {
+  %0:2 = "test.producer"() : () -> (i32, i32)
+  "test.commutatively_equal_permutation"() ({
+    arith.addi %0#0, %0#1 : i32
+  }) : () -> ()
+  "test.commutatively_equal_permutation"() ({
+    arith.addi %0#1, %0#0 : i32
+  }) : () -> ()
+}
+
+// -----
+
+// CHECK-LABEL: test.ignore_commutatively_equal_permutation
+// CHECK-SAME: compares NOT equals
+
+builtin.module attributes {test.includes_setup} {
+  %0:2 = "test.producer"() : () -> (i32, i32)
+  "test.ignore_commutatively_equal_permutation"() ({
+    arith.addi %0#0, %0#1 : i32
+  }) { ignore_commutativity } : () -> ()
+  "test.ignore_commutatively_equal_permutation"() ({
+    arith.addi %0#1, %0#0 : i32
+  }) { ignore_commutativity } : () -> ()
+}
+
+// -----
+
 // CHECK-LABEL: test.commutatively_equal
 // CHECK-SAME: compares equals
 
diff --git a/mlir/test/lib/IR/TestOperationEquals.cpp b/mlir/test/lib/IR/TestOperationEquals.cpp
index a8ff7752ebb2c..cf05056ff613e 100644
--- a/mlir/test/lib/IR/TestOperationEquals.cpp
+++ b/mlir/test/lib/IR/TestOperationEquals.cpp
@@ -23,13 +23,20 @@ struct TestOperationEqualPass
     ModuleOp module = getOperation();
     // Expects two operations at the top-level:
     int opCount = module.getBody()->getOperations().size();
-    if (opCount != 2) {
-      module.emitError() << "expected 2 top-level ops in the module, got "
-                         << opCount;
-      return signalPassFailure();
+    if (module->hasAttr("test.includes_setup")) {
+      if (opCount < 2) {
+        module.emitError() << "expected at least 2 top-level ops in the module, got "
+        << opCount;
+        return signalPassFailure();
+      }
+    } else if (opCount != 2) {
+        module.emitError() << "expected 2 top-level ops in the module, got "
+        << opCount;
+        return signalPassFailure();
     }
+    Operation* second = &module.getBody()->back();
+    Operation* first = second->getPrevNode();
 
-    Operation *first = &module.getBody()->front();
     llvm::outs() << first->getName().getStringRef() << " with attr "
                  << first->getDiscardableAttrDictionary();
     OperationEquivalence::Flags flags{};



More information about the Mlir-commits mailing list