[Mlir-commits] [mlir] [mlir][arith] Canonicalization patterns for `arith.select` (PR #67809)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Sep 29 07:18:31 PDT 2023


https://github.com/peterbell10 created https://github.com/llvm/llvm-project/pull/67809

This adds the following canonicalization patterns:
- `select(pred, select(pred, a, b), c) => select(pred, a, c)`
- `select(pred, a, select(pred, b, c)) => select(pred, a, c)`
- `select(not(pred), a, b) => select(pred, b, a)`

From 1c3b9950e391ddc4accacf34b20fb9aa74bed69b Mon Sep 17 00:00:00 2001
From: Peter Bell <peterbell10 at live.co.uk>
Date: Fri, 29 Sep 2023 00:47:57 +0100
Subject: [PATCH] [mlir][arith] Canonicalization patterns for `arith.select`

This adds the following canonicalization patterns:
- `select(pred, select(pred, a, b), c) => select(pred, a, c)`
- `select(pred, a, select(pred, b, c)) => select(pred, a, c)`
- `select(not(pred), a, b) => select(pred, b, a)`
---
 .../Dialect/Arith/IR/ArithCanonicalization.td | 20 +++++++++++++
 mlir/lib/Dialect/Arith/IR/ArithOps.cpp        |  3 +-
 mlir/test/Dialect/Arith/canonicalize.mlir     | 28 +++++++++++++++++++
 3 files changed, 50 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
index f3d84d0b261e8dd..ea043d35b2cb751 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -233,6 +233,26 @@ def CmpIExtUI :
             CPred<"$0.getValue() == arith::CmpIPredicate::eq || "
                   "$0.getValue() == arith::CmpIPredicate::ne">> $pred)]>;
 
+//===----------------------------------------------------------------------===//
+// SelectOp
+//===----------------------------------------------------------------------===//
+
+// select(not(pred), a, b) => select(pred, b, a); not(pred) = xori(pred, -1)
+def SelectNotCond :
+    Pat<(SelectOp (Arith_XOrIOp $pred, (ConstantLikeMatcher APIntAttr:$ones)), $a, $b),
+        (SelectOp $pred, $b, $a),
+        [(IsScalarOrSplatNegativeOne $ones)]>;
+
+// select(pred, select(pred, a, b), c) => select(pred, a, c)
+def RedundantSelectTrue :
+    Pat<(SelectOp $pred, (SelectOp $pred, $a, $b), $c),
+        (SelectOp $pred, $a, $c)>;
+
+// select(pred, a, select(pred, b, c)) => select(pred, a, c)
+def RedundantSelectFalse :
+    Pat<(SelectOp $pred, $a, (SelectOp $pred, $b, $c)),
+        (SelectOp $pred, $a, $c)>;
+
 //===----------------------------------------------------------------------===//
 // IndexCastOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index ae8a6ef350ce191..56dc92b573b2310 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -2212,7 +2212,8 @@ struct SelectToExtUI : public OpRewritePattern<arith::SelectOp> {
 
 void arith::SelectOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                   MLIRContext *context) {
-  results.add<SelectI1Simplify, SelectToExtUI>(context);
+  results.add<RedundantSelectFalse, RedundantSelectTrue>(context);
+  results.add<SelectI1Simplify, SelectNotCond, SelectToExtUI>(context);
 }
 
 OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) {
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 84096354e6afe33..ea55e5dec9e8377 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -128,6 +128,34 @@ func.func @selToArith(%arg0: i1, %arg1 : i1, %arg2 : i1) -> i1 {
   return %res : i1
 }
 
+// CHECK-LABEL: @redundantSelectTrue
+//  CHECK-NEXT: %[[SEL:.+]] = arith.select %arg0, %arg1, %arg3
+//  CHECK-NEXT: return %[[SEL]]
+func.func @redundantSelectTrue(%arg0: i1, %arg1 : i32, %arg2 : i32, %arg3 : i32) -> i32 {
+  %inner = arith.select %arg0, %arg1, %arg2 : i32
+  %outer = arith.select %arg0, %inner, %arg3 : i32
+  return %outer : i32
+}
+
+// CHECK-LABEL: @redundantSelectFalse
+//  CHECK-NEXT: %[[SEL:.+]] = arith.select %arg0, %arg3, %arg2
+//  CHECK-NEXT: return %[[SEL]]
+func.func @redundantSelectFalse(%arg0: i1, %arg1 : i32, %arg2 : i32, %arg3 : i32) -> i32 {
+  %inner = arith.select %arg0, %arg1, %arg2 : i32
+  %outer = arith.select %arg0, %arg3, %inner : i32
+  return %outer : i32
+}
+
+// CHECK-LABEL: @selNotCond
+//       CHECK-NEXT: %[[res:.+]] = arith.select %arg0, %arg2, %arg1
+//       CHECK-NEXT: return %[[res]]
+func.func @selNotCond(%arg0: i1, %arg1 : i32, %arg2 : i32) -> i32 {
+  %true = arith.constant true
+  %cond = arith.xori %arg0, %true : i1
+  %res = arith.select %cond, %arg1, %arg2 : i32
+  return %res : i32
+}
+
 // Test case: Folding of comparisons with equal operands.
 // CHECK-LABEL: @cmpi_equal_operands
 //   CHECK-DAG:   %[[T:.*]] = arith.constant true



More information about the Mlir-commits mailing list