[Mlir-commits] [mlir] 7219b31 - [mlir] Additional folding for SelectOp
llvmlistbot at llvm.org
Sat Mar 20 03:42:24 PDT 2021
Author: Butygin
Date: 2021-03-20T13:40:42+03:00
New Revision: 7219b31d40f14604c669d633b014d0cc8b707cf3
URL: https://github.com/llvm/llvm-project/commit/7219b31d40f14604c669d633b014d0cc8b707cf3
DIFF: https://github.com/llvm/llvm-project/commit/7219b31d40f14604c669d633b014d0cc8b707cf3.diff
LOG: [mlir] Additional folding for SelectOp
* Fold SelectOp when the true and false operands are the same SSA value
* Fold `select (cmpi eq/ne %a, %b), %a, %b` (in either operand order) to one of the operands
Differential Revision: https://reviews.llvm.org/D98576
Added:
Modified:
mlir/lib/Dialect/StandardOps/IR/Ops.cpp
mlir/test/Dialect/Standard/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index bd38e154bcf6..4830a51827a5 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -1360,15 +1360,38 @@ static LogicalResult verify(ReturnOp op) {
//===----------------------------------------------------------------------===//
OpFoldResult SelectOp::fold(ArrayRef<Attribute> operands) {
+ auto trueVal = getTrueValue(); // Value produced when the condition is true.
+ auto falseVal = getFalseValue(); // Value produced when the condition is false.
+ if (trueVal == falseVal) // select %c, %x, %x => %x, independent of %c.
+ return trueVal;
+
auto condition = getCondition();
// select true, %0, %1 => %0
if (matchPattern(condition, m_One()))
- return getTrueValue();
+ return trueVal;
// select false, %0, %1 => %1
if (matchPattern(condition, m_Zero()))
- return getFalseValue();
+ return falseVal;
+
+ if (auto cmp = dyn_cast_or_null<CmpIOp>(condition.getDefiningOp())) { // Condition produced by cmpi? (_or_null: presumably getDefiningOp() can be null, e.g. for block arguments.)
+ auto pred = cmp.predicate();
+ if (pred == mlir::CmpIPredicate::eq || pred == mlir::CmpIPredicate::ne) { // Only eq/ne are handled: for them the result reduces to a single operand.
+ auto cmpLhs = cmp.lhs();
+ auto cmpRhs = cmp.rhs();
+
+ // %0 = cmpi eq, %arg0, %arg1
+ // %1 = select %0, %arg0, %arg1 => %arg1
+
+ // %0 = cmpi ne, %arg0, %arg1
+ // %1 = select %0, %arg0, %arg1 => %arg0
+
+ if ((cmpLhs == trueVal && cmpRhs == falseVal) || // The compared values are exactly the two select arms,
+ (cmpRhs == trueVal && cmpLhs == falseVal)) // in either order (eq/ne are symmetric).
+ return pred == mlir::CmpIPredicate::ne ? trueVal : falseVal; // When the compared values are equal both arms agree, so eq always yields falseVal and ne always yields trueVal.
+ }
+ }
return nullptr;
}
diff --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir
index a6bf0c78321a..77022024ae48 100644
--- a/mlir/test/Dialect/Standard/canonicalize.mlir
+++ b/mlir/test/Dialect/Standard/canonicalize.mlir
@@ -339,3 +339,32 @@ func @subtensor_insert_output_dest_canonicalize(%arg0 : tensor<2x3xi32>, %arg1 :
// CHECK: %[[GENERATE:.+]] = tensor.generate
// CHECK: %[[RESULT:.+]] = subtensor_insert %[[ARG0]] into %[[GENERATE]]
// CHECK: return %[[RESULT]]
+
+// -----
+
+// CHECK-LABEL: @select_same_val
+// CHECK: return %arg1
+func @select_same_val(%arg0: i1, %arg1: i64) -> i64 {
+ %0 = select %arg0, %arg1, %arg1 : i64 // Both arms are %arg1, so the select folds to it regardless of %arg0.
+ return %0 : i64
+}
+
+// -----
+
+// CHECK-LABEL: @select_cmp_eq_select
+// CHECK: return %arg1
+func @select_cmp_eq_select(%arg0: i64, %arg1: i64) -> i64 {
+ %0 = cmpi eq, %arg0, %arg1 : i64
+ %1 = select %0, %arg0, %arg1 : i64 // If %arg0 == %arg1 both arms agree; otherwise %arg1 is chosen, so the result is always %arg1.
+ return %1 : i64
+}
+
+// -----
+
+// CHECK-LABEL: @select_cmp_ne_select
+// CHECK: return %arg0
+func @select_cmp_ne_select(%arg0: i64, %arg1: i64) -> i64 {
+ %0 = cmpi ne, %arg0, %arg1 : i64
+ %1 = select %0, %arg0, %arg1 : i64 // If %arg0 != %arg1, %arg0 is chosen; otherwise both arms agree, so the result is always %arg0. NOTE(review): the swapped-arm form (select %0, %arg1, %arg0) also folds but is not exercised by these tests.
+ return %1 : i64
+}
More information about the Mlir-commits
mailing list