[Mlir-commits] [mlir] 7219b31 - [mlir] Additional folding for SelectOp

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Mar 20 03:42:24 PDT 2021


Author: Butygin
Date: 2021-03-20T13:40:42+03:00
New Revision: 7219b31d40f14604c669d633b014d0cc8b707cf3

URL: https://github.com/llvm/llvm-project/commit/7219b31d40f14604c669d633b014d0cc8b707cf3
DIFF: https://github.com/llvm/llvm-project/commit/7219b31d40f14604c669d633b014d0cc8b707cf3.diff

LOG: [mlir] Additional folding for SelectOp

* Fold SelectOp when both the true and false args are the same SSA value
* Fold select whose condition is a `cmpi eq`/`cmpi ne` of the select's own true and false args

Differential Revision: https://reviews.llvm.org/D98576

Added: 
    

Modified: 
    mlir/lib/Dialect/StandardOps/IR/Ops.cpp
    mlir/test/Dialect/Standard/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index bd38e154bcf6..4830a51827a5 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -1360,15 +1360,38 @@ static LogicalResult verify(ReturnOp op) {
 //===----------------------------------------------------------------------===//
 
 OpFoldResult SelectOp::fold(ArrayRef<Attribute> operands) {
+  auto trueVal = getTrueValue();
+  auto falseVal = getFalseValue();
+  if (trueVal == falseVal)
+    return trueVal;
+
   auto condition = getCondition();
 
   // select true, %0, %1 => %0
   if (matchPattern(condition, m_One()))
-    return getTrueValue();
+    return trueVal;
 
   // select false, %0, %1 => %1
   if (matchPattern(condition, m_Zero()))
-    return getFalseValue();
+    return falseVal;
+
+  if (auto cmp = dyn_cast_or_null<CmpIOp>(condition.getDefiningOp())) {
+    auto pred = cmp.predicate();
+    if (pred == mlir::CmpIPredicate::eq || pred == mlir::CmpIPredicate::ne) {
+      auto cmpLhs = cmp.lhs();
+      auto cmpRhs = cmp.rhs();
+      // cmpi eq, %a, %b; select %cmp, %a, %b => %b (when a == b, %a is %b)
+      // cmpi ne, %a, %b; select %cmp, %a, %b => %a (when a == b, %b is %a)
+      // The compared values may appear in the select in either order, so
+      // both orderings are matched below. Either way the result is the
+      // select's true value for `ne` and its false value for `eq`, e.g.
+      // cmpi eq, %a, %b; select %cmp, %b, %a => %a.
+
+      if ((cmpLhs == trueVal && cmpRhs == falseVal) ||
+          (cmpRhs == trueVal && cmpLhs == falseVal))
+        return pred == mlir::CmpIPredicate::ne ? trueVal : falseVal;
+    }
+  }
   return nullptr;
 }
 

diff  --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir
index a6bf0c78321a..77022024ae48 100644
--- a/mlir/test/Dialect/Standard/canonicalize.mlir
+++ b/mlir/test/Dialect/Standard/canonicalize.mlir
@@ -339,3 +339,32 @@ func @subtensor_insert_output_dest_canonicalize(%arg0 : tensor<2x3xi32>, %arg1 :
 //       CHECK:   %[[GENERATE:.+]] = tensor.generate
 //       CHECK:   %[[RESULT:.+]] = subtensor_insert %[[ARG0]] into %[[GENERATE]]
 //       CHECK:   return %[[RESULT]]
+
+// -----
+// A select whose true and false operands are the same value folds to it.
+// CHECK-LABEL: @select_same_val
+//       CHECK:   return %arg1
+func @select_same_val(%arg0: i1, %arg1: i64) -> i64 {
+  %0 = select %arg0, %arg1, %arg1 : i64
+  return %0 : i64
+}
+
+// -----
+// select (cmpi eq, %a, %b), %a, %b folds to its false operand %b.
+// CHECK-LABEL: @select_cmp_eq_select
+//       CHECK:   return %arg1
+func @select_cmp_eq_select(%arg0: i64, %arg1: i64) -> i64 {
+  %0 = cmpi eq, %arg0, %arg1 : i64
+  %1 = select %0, %arg0, %arg1 : i64
+  return %1 : i64
+}
+
+// -----
+// select (cmpi ne, %a, %b), %a, %b folds to its true operand %a.
+// CHECK-LABEL: @select_cmp_ne_select
+//       CHECK:   return %arg0
+func @select_cmp_ne_select(%arg0: i64, %arg1: i64) -> i64 {
+  %0 = cmpi ne, %arg0, %arg1 : i64
+  %1 = select %0, %arg0, %arg1 : i64
+  return %1 : i64
+}


        


More information about the Mlir-commits mailing list