[Mlir-commits] [mlir] bd51496 - [mlir][CSE] Add ability to remove commutative operations

Valentin Clement llvmlistbot at llvm.org
Sat Apr 16 12:10:01 PDT 2022


Author: Valentin Clement
Date: 2022-04-16T21:09:47+02:00
New Revision: bd514967aa221bef5d1adaec12abc68511f325f0

URL: https://github.com/llvm/llvm-project/commit/bd514967aa221bef5d1adaec12abc68511f325f0
DIFF: https://github.com/llvm/llvm-project/commit/bd514967aa221bef5d1adaec12abc68511f325f0.diff

LOG: [mlir][CSE] Add ability to remove commutative operations

This patch takes advantage of the Commutative trait on operations
to remove identical commutative operations whose operands are swapped.

The second operation below can be removed since `arith.addi` is commutative.
```
%1 = arith.addi %a, %b : i32
%2 = arith.addi %b, %a : i32
```

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D123492

Added: 
    

Modified: 
    mlir/lib/IR/OperationSupport.cpp
    mlir/test/Transforms/cse.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp
index 012d980955289..aaae4c8ad9da1 100644
--- a/mlir/lib/IR/OperationSupport.cpp
+++ b/mlir/lib/IR/OperationSupport.cpp
@@ -633,8 +633,18 @@ llvm::hash_code OperationEquivalence::computeHash(
       op->getName(), op->getAttrDictionary(), op->getResultTypes());
 
   //   - Operands
-  for (Value operand : op->getOperands())
+  ValueRange operands = op->getOperands();
+  SmallVector<Value> operandStorage;
+  if (op->hasTrait<mlir::OpTrait::IsCommutative>()) {
+    operandStorage.append(operands.begin(), operands.end());
+    llvm::sort(operandStorage, [](Value a, Value b) -> bool {
+      return a.getAsOpaquePointer() < b.getAsOpaquePointer();
+    });
+    operands = operandStorage;
+  }
+  for (Value operand : operands)
     hash = llvm::hash_combine(hash, hashOperands(operand));
+
   //   - Operands
   for (Value result : op->getResults())
     hash = llvm::hash_combine(hash, hashResults(result));
@@ -710,6 +720,21 @@ bool OperationEquivalence::isEquivalentTo(
   if (!(flags & IgnoreLocations) && lhs->getLoc() != rhs->getLoc())
     return false;
 
+  ValueRange lhsOperands = lhs->getOperands(), rhsOperands = rhs->getOperands();
+  SmallVector<Value> lhsOperandStorage, rhsOperandStorage;
+  if (lhs->hasTrait<mlir::OpTrait::IsCommutative>()) {
+    lhsOperandStorage.append(lhsOperands.begin(), lhsOperands.end());
+    llvm::sort(lhsOperandStorage, [](Value a, Value b) -> bool {
+      return a.getAsOpaquePointer() < b.getAsOpaquePointer();
+    });
+    lhsOperands = lhsOperandStorage;
+
+    rhsOperandStorage.append(rhsOperands.begin(), rhsOperands.end());
+    llvm::sort(rhsOperandStorage, [](Value a, Value b) -> bool {
+      return a.getAsOpaquePointer() < b.getAsOpaquePointer();
+    });
+    rhsOperands = rhsOperandStorage;
+  }
   auto checkValueRangeMapping =
       [](ValueRange lhs, ValueRange rhs,
          function_ref<LogicalResult(Value, Value)> mapValues) {
@@ -724,8 +749,7 @@ bool OperationEquivalence::isEquivalentTo(
         return true;
       };
   // Check mapping of operands and results.
-  if (!checkValueRangeMapping(lhs->getOperands(), rhs->getOperands(),
-                              mapOperands))
+  if (!checkValueRangeMapping(lhsOperands, rhsOperands, mapOperands))
     return false;
   if (!checkValueRangeMapping(lhs->getResults(), rhs->getResults(), mapResults))
     return false;

diff  --git a/mlir/test/Transforms/cse.mlir b/mlir/test/Transforms/cse.mlir
index 189cdde18397e..216218e26dbe1 100644
--- a/mlir/test/Transforms/cse.mlir
+++ b/mlir/test/Transforms/cse.mlir
@@ -310,3 +310,15 @@ func @dont_remove_duplicated_read_op_with_sideeffecting() -> i32 {
   %2 = arith.addi %0, %1 : i32
   return %2 : i32
 }
+
+/// This test checks that identical commutative operations are gracefully
+/// handled by the CSE pass.
+// CHECK-LABEL: func @check_cummutative_cse
+func @check_cummutative_cse(%a : i32, %b : i32) -> i32 {
+  // CHECK: %[[ADD1:.*]] = arith.addi %{{.*}}, %{{.*}} : i32
+  %1 = arith.addi %a, %b : i32
+  %2 = arith.addi %b, %a : i32
+  // CHECK-NEXT:  arith.muli %[[ADD1]], %[[ADD1]] : i32
+  %3 = arith.muli %1, %2 : i32
+  return %3 : i32
+}


        


More information about the Mlir-commits mailing list