[Mlir-commits] [mlir] [mlir] Support full commutative operation equality (PR #192652)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Apr 17 06:48:32 PDT 2026


https://github.com/jumerckx updated https://github.com/llvm/llvm-project/pull/192652

>From 3ca34e203402d30ba6134cfa8317d9c03499d03f Mon Sep 17 00:00:00 2001
From: jumerckx <julesmerckx12 at gmail.com>
Date: Sun, 15 Feb 2026 22:57:55 +0100
Subject: [PATCH 1/7] test with commutativity

---
 mlir/test/IR/operation-equality.mlir | 15 +++++++++++++++
 1 file changed, 15 insertions(+)

diff --git a/mlir/test/IR/operation-equality.mlir b/mlir/test/IR/operation-equality.mlir
index 7bdbfdd19a62d..405701aafb127 100644
--- a/mlir/test/IR/operation-equality.mlir
+++ b/mlir/test/IR/operation-equality.mlir
@@ -187,6 +187,21 @@
 
 // -----
 
+// CHECK-LABEL: test.commutatively_equal_permutation
+// CHECK-SAME: compares equals
+
+builtin.module attributes {test.includes_setup} {
+  %0:2 = "test.producer"() : () -> (i32, i32)
+  "test.commutatively_equal_permutation"() ({
+    arith.addi %0#0, %0#1 : i32
+  }) : () -> ()
+  "test.commutatively_equal_permutation"() ({
+    arith.addi %0#1, %0#0 : i32
+  }) : () -> ()
+}
+
+// -----
+
 // CHECK-LABEL: test.ignore_commutatively_equal_permutation
 // CHECK-SAME: compares NOT equals
 

>From c93a533c09efdfa27d8c57b4431b8ecb72015e51 Mon Sep 17 00:00:00 2001
From: jumerckx <julesmerckx12 at gmail.com>
Date: Sat, 14 Feb 2026 20:07:11 +0100
Subject: [PATCH 2/7] propagate checkCommutativeEquivalent

---
 mlir/lib/IR/OperationSupport.cpp | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp
index 31f71c3949003..91c62a6604c16 100644
--- a/mlir/lib/IR/OperationSupport.cpp
+++ b/mlir/lib/IR/OperationSupport.cpp
@@ -893,9 +893,9 @@ OperationEquivalence::isRegionEquivalentTo(Region *lhs, Region *rhs,
 
   // 4. Compare regions.
   for (auto regionPair : llvm::zip(lhs->getRegions(), rhs->getRegions()))
-    if (!isRegionEquivalentTo(&std::get<0>(regionPair),
-                              &std::get<1>(regionPair), checkEquivalent,
-                              markEquivalent, flags))
+    if (!isRegionEquivalentTo(
+            &std::get<0>(regionPair), &std::get<1>(regionPair), checkEquivalent,
+            markEquivalent, flags, checkCommutativeEquivalent))
       return false;
 
   return true;

>From 0a94d5b65910fd5afba58452aeb2a1b12d48003a Mon Sep 17 00:00:00 2001
From: jumerckx <julesmerckx12 at gmail.com>
Date: Sat, 14 Feb 2026 20:07:24 +0100
Subject: [PATCH 3/7] fully check commutative equivalence

---
 mlir/lib/IR/OperationSupport.cpp | 11 ++++++++++-
 1 file changed, 10 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp
index 91c62a6604c16..dfbdb441caebd 100644
--- a/mlir/lib/IR/OperationSupport.cpp
+++ b/mlir/lib/IR/OperationSupport.cpp
@@ -803,7 +803,16 @@ struct ValueEquivalenceCache {
     };
     auto lhsSorted = sortValues({lhsIt, lhsRange.end()});
     auto rhsSorted = sortValues({rhsIt, rhsRange.end()});
-    return success(lhsSorted == rhsSorted);
+    if (lhsSorted == rhsSorted) {
+      return success();
+    }
+    for (auto operandPair : llvm::zip(lhsSorted, rhsSorted)) {
+      Value lhs = std::get<0>(operandPair);
+      Value rhs = std::get<1>(operandPair);
+      if (failed(checkEquivalent(lhs, rhs)))
+        return failure();
+    }
+    return success();
   }
   void markEquivalent(Value lhsResult, Value rhsResult) {
     auto insertion = equivalentValues.insert({lhsResult, rhsResult});

>From ee8745a44b76d23a3dc57e57c523e4df23d9258f Mon Sep 17 00:00:00 2001
From: jumerckx <julesmerckx12 at gmail.com>
Date: Sat, 14 Feb 2026 21:43:28 +0100
Subject: [PATCH 4/7] check full commutativity by using equivalentValues as a
 union-find

---
 mlir/lib/IR/OperationSupport.cpp | 19 +++++++------------
 1 file changed, 7 insertions(+), 12 deletions(-)

diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp
index dfbdb441caebd..47ba299424162 100644
--- a/mlir/lib/IR/OperationSupport.cpp
+++ b/mlir/lib/IR/OperationSupport.cpp
@@ -791,11 +791,12 @@ struct ValueEquivalenceCache {
     if (lhsIt == lhsRange.end())
       return success();
 
-    // Handle another simple case where operands are just a permutation.
-    // Note: This is not sufficient, this handles simple cases relatively
-    // cheaply.
-    auto sortValues = [](ValueRange values) {
-      SmallVector<Value> sortedValues = llvm::to_vector(values);
+    // Replace each value with its entry in equivalentValues, if it has one.
+    // That way, comparing the sorted pointer lists is enough to determine
+    // commutative equivalence.
+    auto sortValues = [this](ValueRange values) {
+      SmallVector<Value> sortedValues = llvm::map_to_vector(
+          values, [this](Value a) { return equivalentValues.lookup_or(a, a); });
       llvm::sort(sortedValues, [](Value a, Value b) {
         return a.getAsOpaquePointer() < b.getAsOpaquePointer();
       });
@@ -806,13 +807,7 @@ struct ValueEquivalenceCache {
     if (lhsSorted == rhsSorted) {
       return success();
     }
-    for (auto operandPair : llvm::zip(lhsSorted, rhsSorted)) {
-      Value lhs = std::get<0>(operandPair);
-      Value rhs = std::get<1>(operandPair);
-      if (failed(checkEquivalent(lhs, rhs)))
-        return failure();
-    }
-    return success();
+    return failure();
   }
   void markEquivalent(Value lhsResult, Value rhsResult) {
     auto insertion = equivalentValues.insert({lhsResult, rhsResult});

>From 15fc681f09e09b4759b66cbc1e3a675b8d1bf121 Mon Sep 17 00:00:00 2001
From: jumerckx <julesmerckx12 at gmail.com>
Date: Sat, 14 Feb 2026 20:33:25 +0100
Subject: [PATCH 5/7] update test

---
 mlir/test/Dialect/Func/duplicate-function-elimination.mlir | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/mlir/test/Dialect/Func/duplicate-function-elimination.mlir b/mlir/test/Dialect/Func/duplicate-function-elimination.mlir
index bc04e8fa9cd23..4d00d8a954d17 100644
--- a/mlir/test/Dialect/Func/duplicate-function-elimination.mlir
+++ b/mlir/test/Dialect/Func/duplicate-function-elimination.mlir
@@ -58,11 +58,10 @@ func.func @user(%arg0: f32, %arg1: f32) -> f32 {
 
 // CHECK:     @add_lr
 // CHECK-NOT: @also_add_lr
-// CHECK:     @add_rl
+// CHECK-NOT:     @add_rl
 // CHECK-NOT: @also_add_rl
 // CHECK:     @user
-// CHECK-2:     call @add_lr
-// CHECK-2:     call @add_rl
+// CHECK-4:     call @add_lr
 
 // -----
 

>From e89ae96c8fe3263c1ebd06a71065d813af78faa5 Mon Sep 17 00:00:00 2001
From: jumerckx <31353884+jumerckx at users.noreply.github.com>
Date: Sat, 14 Feb 2026 17:22:51 +0100
Subject: [PATCH 6/7] test full commutative equality

---
 mlir/test/IR/operation-equality.mlir | 28 ++++++++++++++++++++++++++++
 1 file changed, 28 insertions(+)

diff --git a/mlir/test/IR/operation-equality.mlir b/mlir/test/IR/operation-equality.mlir
index 405701aafb127..ed51b610fbc46 100644
--- a/mlir/test/IR/operation-equality.mlir
+++ b/mlir/test/IR/operation-equality.mlir
@@ -214,3 +214,31 @@ builtin.module attributes {test.includes_setup} {
     arith.addi %0#1, %0#0 : i32
   }) { ignore_commutativity } : () -> ()
 }
+
+// -----
+
+// CHECK-LABEL: test.commutatively_equal
+// CHECK-SAME: compares equals
+
+"test.commutatively_equal"() ({
+  ^bb0(%arg0 : i32, %arg1 : i32):
+    arith.addi %arg0, %arg1 : i32
+  }) : () -> ()
+"test.commutatively_equal"() ({
+  ^bb0(%arg0 : i32, %arg1 : i32):
+    arith.addi %arg1, %arg0 : i32
+  }) : () -> ()
+
+// -----
+
+// CHECK-LABEL: test.ignore_commutatively_equal
+// CHECK-SAME: compares NOT equals
+
+"test.ignore_commutatively_equal"() ({
+  ^bb0(%arg0 : i32, %arg1 : i32):
+    arith.addi %arg0, %arg1 : i32
+  }) { ignore_commutativity } : () -> ()
+"test.ignore_commutatively_equal"() ({
+  ^bb0(%arg0 : i32, %arg1 : i32):
+    arith.addi %arg1, %arg0 : i32
+  }) { ignore_commutativity } : () -> ()

>From 4832079b48ae17c8fc4c8fb7817097c328106964 Mon Sep 17 00:00:00 2001
From: jumerckx <31353884+jumerckx at users.noreply.github.com>
Date: Fri, 17 Apr 2026 15:47:46 +0200
Subject: [PATCH 7/7] revert return expression

---
 mlir/lib/IR/OperationSupport.cpp | 5 +----
 1 file changed, 1 insertion(+), 4 deletions(-)

diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp
index 47ba299424162..32c6426429ae8 100644
--- a/mlir/lib/IR/OperationSupport.cpp
+++ b/mlir/lib/IR/OperationSupport.cpp
@@ -804,10 +804,7 @@ struct ValueEquivalenceCache {
     };
     auto lhsSorted = sortValues({lhsIt, lhsRange.end()});
     auto rhsSorted = sortValues({rhsIt, rhsRange.end()});
-    if (lhsSorted == rhsSorted) {
-      return success();
-    }
-    return failure();
+    return success(lhsSorted == rhsSorted);
   }
   void markEquivalent(Value lhsResult, Value rhsResult) {
     auto insertion = equivalentValues.insert({lhsResult, rhsResult});



More information about the Mlir-commits mailing list