[Mlir-commits] [mlir] [mlir] Support full commutative operation equality (PR #192652)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Apr 17 06:34:20 PDT 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: jumerckx
<details>
<summary>Changes</summary>
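Currently, commutative equality is only detected when the two operand lists are permutations of one another (i.e. they contain exactly the same values in a different order).
By first mapping each operand through the `equivalentValues` map onto a common set of values (so that values already known to be equivalent compare equal), I believe full commutative equality can be achieved relatively cheaply: a sorted pointer comparison of the mapped operand lists is then sufficient.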
Does this seem like a good approach?
---
Full diff: https://github.com/llvm/llvm-project/pull/192652.diff
3 Files Affected:
- (modified) mlir/lib/IR/OperationSupport.cpp (+13-9)
- (modified) mlir/test/Dialect/Func/duplicate-function-elimination.mlir (+2-3)
- (modified) mlir/test/IR/operation-equality.mlir (+43)
``````````diff
diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp
index 31f71c3949003..47ba299424162 100644
--- a/mlir/lib/IR/OperationSupport.cpp
+++ b/mlir/lib/IR/OperationSupport.cpp
@@ -791,11 +791,12 @@ struct ValueEquivalenceCache {
if (lhsIt == lhsRange.end())
return success();
- // Handle another simple case where operands are just a permutation.
- // Note: This is not sufficient, this handles simple cases relatively
- // cheaply.
- auto sortValues = [](ValueRange values) {
- SmallVector<Value> sortedValues = llvm::to_vector(values);
+ // Replace values with their entry in equivalentValues if they're in there
+ // that way, a sorted pointer comparison is enough to determine
+ // commutativity.
+ auto sortValues = [this](ValueRange values) {
+ SmallVector<Value> sortedValues = llvm::map_to_vector(
+ values, [this](Value a) { return equivalentValues.lookup_or(a, a); });
llvm::sort(sortedValues, [](Value a, Value b) {
return a.getAsOpaquePointer() < b.getAsOpaquePointer();
});
@@ -803,7 +804,10 @@ struct ValueEquivalenceCache {
};
auto lhsSorted = sortValues({lhsIt, lhsRange.end()});
auto rhsSorted = sortValues({rhsIt, rhsRange.end()});
- return success(lhsSorted == rhsSorted);
+ if (lhsSorted == rhsSorted) {
+ return success();
+ }
+ return failure();
}
void markEquivalent(Value lhsResult, Value rhsResult) {
auto insertion = equivalentValues.insert({lhsResult, rhsResult});
@@ -893,9 +897,9 @@ OperationEquivalence::isRegionEquivalentTo(Region *lhs, Region *rhs,
// 4. Compare regions.
for (auto regionPair : llvm::zip(lhs->getRegions(), rhs->getRegions()))
- if (!isRegionEquivalentTo(&std::get<0>(regionPair),
- &std::get<1>(regionPair), checkEquivalent,
- markEquivalent, flags))
+ if (!isRegionEquivalentTo(
+ &std::get<0>(regionPair), &std::get<1>(regionPair), checkEquivalent,
+ markEquivalent, flags, checkCommutativeEquivalent))
return false;
return true;
diff --git a/mlir/test/Dialect/Func/duplicate-function-elimination.mlir b/mlir/test/Dialect/Func/duplicate-function-elimination.mlir
index bc04e8fa9cd23..4d00d8a954d17 100644
--- a/mlir/test/Dialect/Func/duplicate-function-elimination.mlir
+++ b/mlir/test/Dialect/Func/duplicate-function-elimination.mlir
@@ -58,11 +58,10 @@ func.func @user(%arg0: f32, %arg1: f32) -> f32 {
// CHECK: @add_lr
// CHECK-NOT: @also_add_lr
-// CHECK: @add_rl
+// CHECK-NOT: @add_rl
// CHECK-NOT: @also_add_rl
// CHECK: @user
-// CHECK-2: call @add_lr
-// CHECK-2: call @add_rl
+// CHECK-4: call @add_lr
// -----
diff --git a/mlir/test/IR/operation-equality.mlir b/mlir/test/IR/operation-equality.mlir
index 7bdbfdd19a62d..ed51b610fbc46 100644
--- a/mlir/test/IR/operation-equality.mlir
+++ b/mlir/test/IR/operation-equality.mlir
@@ -187,6 +187,21 @@
// -----
+// CHECK-LABEL: test.commutatively_equal_permutation
+// CHECK-SAME: compares equals
+
+builtin.module attributes {test.includes_setup} {
+ %0:2 = "test.producer"() : () -> (i32, i32)
+ "test.commutatively_equal_permutation"() ({
+ arith.addi %0#0, %0#1 : i32
+ }) : () -> ()
+ "test.commutatively_equal_permutation"() ({
+ arith.addi %0#1, %0#0 : i32
+ }) : () -> ()
+}
+
+// -----
+
// CHECK-LABEL: test.ignore_commutatively_equal_permutation
// CHECK-SAME: compares NOT equals
@@ -199,3 +214,31 @@ builtin.module attributes {test.includes_setup} {
arith.addi %0#1, %0#0 : i32
}) { ignore_commutativity } : () -> ()
}
+
+// -----
+
+// CHECK-LABEL: test.commutatively_equal
+// CHECK-SAME: compares equals
+
+"test.commutatively_equal"() ({
+ ^bb0(%arg0 : i32, %arg1 : i32):
+ arith.addi %arg0, %arg1 : i32
+ }) : () -> ()
+"test.commutatively_equal"() ({
+ ^bb0(%arg0 : i32, %arg1 : i32):
+ arith.addi %arg1, %arg0 : i32
+ }) : () -> ()
+
+// -----
+
+// CHECK-LABEL: test.ignore_commutatively_equal
+// CHECK-SAME: compares NOT equals
+
+"test.ignore_commutatively_equal"() ({
+ ^bb0(%arg0 : i32, %arg1 : i32):
+ arith.addi %arg0, %arg1 : i32
+ }) { ignore_commutativity } : () -> ()
+"test.ignore_commutatively_equal"() ({
+ ^bb0(%arg0 : i32, %arg1 : i32):
+ arith.addi %arg1, %arg0 : i32
+ }) { ignore_commutativity } : () -> ()
``````````
</details>
https://github.com/llvm/llvm-project/pull/192652
More information about the Mlir-commits
mailing list