[Mlir-commits] [mlir] 2109587 - [MLIR] Don't sort operands of commutative ops when comparing two ops as there is a correctness issue

Jacques Pienaar llvmlistbot at llvm.org
Fri Jul 14 16:12:12 PDT 2023


Author: tomnatan
Date: 2023-07-14T16:11:54-07:00
New Revision: 2109587cee341329c49fa999725ce6a486621b37

URL: https://github.com/llvm/llvm-project/commit/2109587cee341329c49fa999725ce6a486621b37
DIFF: https://github.com/llvm/llvm-project/commit/2109587cee341329c49fa999725ce6a486621b37.diff

LOG: [MLIR] Don't sort operands of commutative ops when comparing two ops as there is a correctness issue

This feature was introduced in `D123492`.

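Sorting the operands of commutative operations by pointer value is incorrect when checking the equivalence of ops in separate regions, where corresponding lhs and rhs operands are marked as equivalent but are not the same value. Because such equivalent values have unrelated pointer orders, the two ops can end up sorted into different operand orders, so the comparison gives the wrong result.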

It was also discussed in `D123492` and `D129480` that the correct solution would be to stably sort the operands during canonicalization (perhaps based on some numbering within the region). Until that lands, reverting this change unblocks us and other users.

An example of a pass that might not work properly because of this is `DuplicateFunctionEliminationPass`.

Reviewed By: mehdi_amini, jpienaar

Differential Revision: https://reviews.llvm.org/D154699

Added: 
    

Modified: 
    flang/test/Fir/commute.fir
    mlir/lib/IR/OperationSupport.cpp
    mlir/test/Dialect/Func/duplicate-function-elimination.mlir
    mlir/test/Transforms/cse.mlir

Removed: 
    


################################################################################
diff  --git a/flang/test/Fir/commute.fir b/flang/test/Fir/commute.fir
index be25fd61953deb..b39fe48145b3f3 100644
--- a/flang/test/Fir/commute.fir
+++ b/flang/test/Fir/commute.fir
@@ -1,4 +1,7 @@
 // RUN: fir-opt %s | tco | FileCheck %s
+//
+// XFAIL:*
+// See: https://github.com/llvm/llvm-project/issues/63784
 
 // CHECK-LABEL: define i32 @f1(i32 %0, i32 %1)
 func.func @f1(%a : i32, %b : i32) -> i32 {

diff  --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp
index 028d3f2639cf08..79cc38da051ee1 100644
--- a/mlir/lib/IR/OperationSupport.cpp
+++ b/mlir/lib/IR/OperationSupport.cpp
@@ -661,19 +661,10 @@ llvm::hash_code OperationEquivalence::computeHash(
     hash = llvm::hash_combine(hash, op->getLoc());
 
   //   - Operands
-  ValueRange operands = op->getOperands();
-  SmallVector<Value> operandStorage;
-  if (op->hasTrait<mlir::OpTrait::IsCommutative>()) {
-    operandStorage.append(operands.begin(), operands.end());
-    llvm::sort(operandStorage, [](Value a, Value b) -> bool {
-      return a.getAsOpaquePointer() < b.getAsOpaquePointer();
-    });
-    operands = operandStorage;
-  }
-  for (Value operand : operands)
+  for (Value operand : op->getOperands())
     hash = llvm::hash_combine(hash, hashOperands(operand));
 
-  //   - Operands
+  //   - Results
   for (Value result : op->getResults())
     hash = llvm::hash_combine(hash, hashResults(result));
   return hash;
@@ -784,41 +775,7 @@ OperationEquivalence::isRegionEquivalentTo(Region *lhs, Region *rhs,
     return false;
 
   // 2. Compare operands.
-  ValueRange lhsOperands = lhs->getOperands(), rhsOperands = rhs->getOperands();
-  SmallVector<Value> lhsOperandStorage, rhsOperandStorage;
-  if (lhs->hasTrait<mlir::OpTrait::IsCommutative>()) {
-    auto sortValues = [](ValueRange values) {
-      SmallVector<Value> sortedValues = llvm::to_vector(values);
-      llvm::sort(sortedValues, [](Value a, Value b) {
-        auto aArg = llvm::dyn_cast<BlockArgument>(a);
-        auto bArg = llvm::dyn_cast<BlockArgument>(b);
-
-        // Case 1. Both `a` and `b` are `BlockArgument`s.
-        if (aArg && bArg) {
-          if (aArg.getParentBlock() == bArg.getParentBlock())
-            return aArg.getArgNumber() < bArg.getArgNumber();
-          return aArg.getParentBlock() < bArg.getParentBlock();
-        }
-
-        // Case 2. One of then is a `BlockArgument` and other is not. Treat
-        // `BlockArgument` as lesser.
-        if (aArg && !bArg)
-          return true;
-        if (bArg && !aArg)
-          return false;
-
-        // Case 3. Both are values.
-        return a.getAsOpaquePointer() < b.getAsOpaquePointer();
-      });
-      return sortedValues;
-    };
-    lhsOperandStorage = sortValues(lhsOperands);
-    lhsOperands = lhsOperandStorage;
-    rhsOperandStorage = sortValues(rhsOperands);
-    rhsOperands = rhsOperandStorage;
-  }
-
-  for (auto operandPair : llvm::zip(lhsOperands, rhsOperands)) {
+  for (auto operandPair : llvm::zip(lhs->getOperands(), rhs->getOperands())) {
     Value curArg = std::get<0>(operandPair);
     Value otherArg = std::get<1>(operandPair);
     if (curArg == otherArg)

diff  --git a/mlir/test/Dialect/Func/duplicate-function-elimination.mlir b/mlir/test/Dialect/Func/duplicate-function-elimination.mlir
index acf2bfb97cdb93..28d059a149bde8 100644
--- a/mlir/test/Dialect/Func/duplicate-function-elimination.mlir
+++ b/mlir/test/Dialect/Func/duplicate-function-elimination.mlir
@@ -58,10 +58,11 @@ func.func @user(%arg0: f32, %arg1: f32) -> f32 {
 
 // CHECK:     @add_lr
 // CHECK-NOT: @also_add_lr
-// CHECK-NOT: @add_rl
+// CHECK:     @add_rl
 // CHECK-NOT: @also_add_rl
 // CHECK:     @user
-// CHECK-4:     call @add_lr
+// CHECK-2:     call @add_lr
+// CHECK-2:     call @add_rl
 
 // -----
 
@@ -108,7 +109,7 @@ func.func @user(%pred : i1, %arg0: f32, %arg1: f32) -> f32 {
 
 // -----
 
-func.func @deep_tree(%p0: i1, %p1: i1, %p2: i1, %p3: i1, %even: f32, %odd: f32) 
+func.func @deep_tree(%p0: i1, %p1: i1, %p2: i1, %p3: i1, %even: f32, %odd: f32)
     -> f32 {
   %0 = scf.if %p0 -> f32 {
     %1 = scf.if %p1 -> f32 {
@@ -188,7 +189,7 @@ func.func @deep_tree(%p0: i1, %p1: i1, %p2: i1, %p3: i1, %even: f32, %odd: f32)
   return %0 : f32
 }
 
-func.func @also_deep_tree(%p0: i1, %p1: i1, %p2: i1, %p3: i1, %even: f32, 
+func.func @also_deep_tree(%p0: i1, %p1: i1, %p2: i1, %p3: i1, %even: f32,
     %odd: f32) -> f32 {
   %0 = scf.if %p0 -> f32 {
     %1 = scf.if %p1 -> f32 {
@@ -268,7 +269,7 @@ func.func @also_deep_tree(%p0: i1, %p1: i1, %p2: i1, %p3: i1, %even: f32,
   return %0 : f32
 }
 
-func.func @reverse_deep_tree(%p0: i1, %p1: i1, %p2: i1, %p3: i1, %even: f32, 
+func.func @reverse_deep_tree(%p0: i1, %p1: i1, %p2: i1, %p3: i1, %even: f32,
     %odd: f32) -> f32 {
   %0 = scf.if %p0 -> f32 {
     %1 = scf.if %p1 -> f32 {
@@ -348,13 +349,13 @@ func.func @reverse_deep_tree(%p0: i1, %p1: i1, %p2: i1, %p3: i1, %even: f32,
   return %0 : f32
 }
 
-func.func @user(%p0: i1, %p1: i1, %p2: i1, %p3: i1, %odd: f32, %even: f32) 
+func.func @user(%p0: i1, %p1: i1, %p2: i1, %p3: i1, %odd: f32, %even: f32)
     -> (f32, f32, f32) {
-  %0 = call @deep_tree(%p0, %p1, %p2, %p3, %odd, %even) 
+  %0 = call @deep_tree(%p0, %p1, %p2, %p3, %odd, %even)
       : (i1, i1, i1, i1, f32, f32) -> f32
-  %1 = call @also_deep_tree(%p0, %p1, %p2, %p3, %odd, %even) 
+  %1 = call @also_deep_tree(%p0, %p1, %p2, %p3, %odd, %even)
       : (i1, i1, i1, i1, f32, f32) -> f32
-  %2 = call @reverse_deep_tree(%p0, %p1, %p2, %p3, %odd, %even) 
+  %2 = call @reverse_deep_tree(%p0, %p1, %p2, %p3, %odd, %even)
       : (i1, i1, i1, i1, f32, f32) -> f32
   return %0, %1, %2 : f32, f32, f32
 }

diff  --git a/mlir/test/Transforms/cse.mlir b/mlir/test/Transforms/cse.mlir
index 7086f5f462f592..f3a820f8a765be 100644
--- a/mlir/test/Transforms/cse.mlir
+++ b/mlir/test/Transforms/cse.mlir
@@ -311,18 +311,6 @@ func.func @dont_remove_duplicated_read_op_with_sideeffecting() -> i32 {
   return %2 : i32
 }
 
-/// This test is checking that identical commutative operation are gracefully
-/// handled but the CSE pass.
-// CHECK-LABEL: func @check_cummutative_cse
-func.func @check_cummutative_cse(%a : i32, %b : i32) -> i32 {
-  // CHECK: %[[ADD1:.*]] = arith.addi %{{.*}}, %{{.*}} : i32
-  %1 = arith.addi %a, %b : i32
-  %2 = arith.addi %b, %a : i32
-  // CHECK-NEXT:  arith.muli %[[ADD1]], %[[ADD1]] : i32
-  %3 = arith.muli %1, %2 : i32
-  return %3 : i32
-}
-
 // Check that an operation with a single region can CSE.
 func.func @cse_single_block_ops(%a : tensor<?x?xf32>, %b : tensor<?x?xf32>)
   -> (tensor<?x?xf32>, tensor<?x?xf32>) {
@@ -425,31 +413,9 @@ func.func @no_cse_single_block_ops_different_bodies(%a : tensor<?x?xf32>, %b : t
 //       CHECK:   %[[OP1:.+]] = test.cse_of_single_block_op
 //       CHECK:   return %[[OP0]], %[[OP1]]
 
-// Account for commutative ops within regions during CSE.
-func.func @cse_single_block_with_commutative_ops(%a : tensor<?x?xf32>, %b : tensor<?x?xf32>, %c : f32)
-  -> (tensor<?x?xf32>, tensor<?x?xf32>) {
-  %0 = test.cse_of_single_block_op inputs(%a, %b) {
-    ^bb0(%arg0 : f32, %arg1 : f32):
-    %1 = arith.addf %arg0, %arg1 : f32
-    %2 = arith.mulf %1, %c : f32
-    test.region_yield %2 : f32
-  } : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
-  %1 = test.cse_of_single_block_op inputs(%a, %b) {
-    ^bb0(%arg0 : f32, %arg1 : f32):
-    %1 = arith.addf %arg1, %arg0 : f32
-    %2 = arith.mulf %c, %1 : f32
-    test.region_yield %2 : f32
-  } : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
-  return %0, %1 : tensor<?x?xf32>, tensor<?x?xf32>
-}
-// CHECK-LABEL: func @cse_single_block_with_commutative_ops
-//       CHECK:   %[[OP:.+]] = test.cse_of_single_block_op
-//   CHECK-NOT:   test.cse_of_single_block_op
-//       CHECK:   return %[[OP]], %[[OP]]
-
 func.func @failing_issue_59135(%arg0: tensor<2x2xi1>, %arg1: f32, %arg2 : tensor<2xi1>) -> (tensor<2xi1>, tensor<2xi1>) {
-  %false_2 = arith.constant false 
-  %true_5 = arith.constant true 
+  %false_2 = arith.constant false
+  %true_5 = arith.constant true
   %9 = test.cse_of_single_block_op inputs(%arg2) {
   ^bb0(%out: i1):
     %true_144 = arith.constant true


        


More information about the Mlir-commits mailing list