[Mlir-commits] [mlir] bd51496 - [mlir][CSE] Add ability to remove commutative operations
Valentin Clement
llvmlistbot at llvm.org
Sat Apr 16 12:10:01 PDT 2022
Author: Valentin Clement
Date: 2022-04-16T21:09:47+02:00
New Revision: bd514967aa221bef5d1adaec12abc68511f325f0
URL: https://github.com/llvm/llvm-project/commit/bd514967aa221bef5d1adaec12abc68511f325f0
DIFF: https://github.com/llvm/llvm-project/commit/bd514967aa221bef5d1adaec12abc68511f325f0.diff
LOG: [mlir][CSE] Add ability to remove commutative operations
This patch takes advantage of the Commutative trait on operations
to remove identical commutative operations whose operands are swapped.
The second operation below can be removed since `arith.addi` is commutative.
```
%1 = arith.addi %a, %b : i32
%2 = arith.addi %b, %a : i32
```
Reviewed By: rriddle
Differential Revision: https://reviews.llvm.org/D123492
Added:
Modified:
mlir/lib/IR/OperationSupport.cpp
mlir/test/Transforms/cse.mlir
Removed:
################################################################################
diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp
index 012d980955289..aaae4c8ad9da1 100644
--- a/mlir/lib/IR/OperationSupport.cpp
+++ b/mlir/lib/IR/OperationSupport.cpp
@@ -633,8 +633,18 @@ llvm::hash_code OperationEquivalence::computeHash(
op->getName(), op->getAttrDictionary(), op->getResultTypes());
// - Operands
- for (Value operand : op->getOperands())
+ ValueRange operands = op->getOperands();
+ SmallVector<Value> operandStorage;
+ if (op->hasTrait<mlir::OpTrait::IsCommutative>()) {
+ operandStorage.append(operands.begin(), operands.end());
+ llvm::sort(operandStorage, [](Value a, Value b) -> bool {
+ return a.getAsOpaquePointer() < b.getAsOpaquePointer();
+ });
+ operands = operandStorage;
+ }
+ for (Value operand : operands)
hash = llvm::hash_combine(hash, hashOperands(operand));
+
// - Operands
for (Value result : op->getResults())
hash = llvm::hash_combine(hash, hashResults(result));
@@ -710,6 +720,21 @@ bool OperationEquivalence::isEquivalentTo(
if (!(flags & IgnoreLocations) && lhs->getLoc() != rhs->getLoc())
return false;
+ ValueRange lhsOperands = lhs->getOperands(), rhsOperands = rhs->getOperands();
+ SmallVector<Value> lhsOperandStorage, rhsOperandStorage;
+ if (lhs->hasTrait<mlir::OpTrait::IsCommutative>()) {
+ lhsOperandStorage.append(lhsOperands.begin(), lhsOperands.end());
+ llvm::sort(lhsOperandStorage, [](Value a, Value b) -> bool {
+ return a.getAsOpaquePointer() < b.getAsOpaquePointer();
+ });
+ lhsOperands = lhsOperandStorage;
+
+ rhsOperandStorage.append(rhsOperands.begin(), rhsOperands.end());
+ llvm::sort(rhsOperandStorage, [](Value a, Value b) -> bool {
+ return a.getAsOpaquePointer() < b.getAsOpaquePointer();
+ });
+ rhsOperands = rhsOperandStorage;
+ }
auto checkValueRangeMapping =
[](ValueRange lhs, ValueRange rhs,
function_ref<LogicalResult(Value, Value)> mapValues) {
@@ -724,8 +749,7 @@ bool OperationEquivalence::isEquivalentTo(
return true;
};
// Check mapping of operands and results.
- if (!checkValueRangeMapping(lhs->getOperands(), rhs->getOperands(),
- mapOperands))
+ if (!checkValueRangeMapping(lhsOperands, rhsOperands, mapOperands))
return false;
if (!checkValueRangeMapping(lhs->getResults(), rhs->getResults(), mapResults))
return false;
diff --git a/mlir/test/Transforms/cse.mlir b/mlir/test/Transforms/cse.mlir
index 189cdde18397e..216218e26dbe1 100644
--- a/mlir/test/Transforms/cse.mlir
+++ b/mlir/test/Transforms/cse.mlir
@@ -310,3 +310,15 @@ func @dont_remove_duplicated_read_op_with_sideeffecting() -> i32 {
%2 = arith.addi %0, %1 : i32
return %2 : i32
}
+
+/// This test checks that identical commutative operations are gracefully
+/// handled by the CSE pass.
+// CHECK-LABEL: func @check_cummutative_cse
+func @check_cummutative_cse(%a : i32, %b : i32) -> i32 {
+ // CHECK: %[[ADD1:.*]] = arith.addi %{{.*}}, %{{.*}} : i32
+ %1 = arith.addi %a, %b : i32
+ %2 = arith.addi %b, %a : i32
+ // CHECK-NEXT: arith.muli %[[ADD1]], %[[ADD1]] : i32
+ %3 = arith.muli %1, %2 : i32
+ return %3 : i32
+}
More information about the Mlir-commits
mailing list