[Mlir-commits] [mlir] [mlir][arith] Canonicalization patterns for `arith.select` (PR #67809)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Sep 29 15:55:38 PDT 2023
https://github.com/peterbell10 updated https://github.com/llvm/llvm-project/pull/67809
>From 437b36c5843a7636c6a1b59d3a1fb142810b59b6 Mon Sep 17 00:00:00 2001
From: Peter Bell <peterbell10 at live.co.uk>
Date: Fri, 29 Sep 2023 00:47:57 +0100
Subject: [PATCH] [mlir][arith] Canonicalization patterns for `arith.select`
This adds the following canonicalization patterns:
- Inverting the condition:
  - `select(not(pred), a, b) => select(pred, b, a)`
- Merging consecutive selects with the same predicate:
  - `select(pred, select(pred, a, b), c) => select(pred, a, c)`
  - `select(pred, a, select(pred, b, c)) => select(pred, a, c)`
- Merging consecutive selects that share a common value:
  - `select(predA, select(predB, a, b), b) => select(and(predA, predB), a, b)`
  - `select(predA, select(predB, b, a), b) => select(and(predA, not(predB)), a, b)`
  - `select(predA, a, select(predB, a, b)) => select(or(predA, predB), a, b)`
  - `select(predA, a, select(predB, b, a)) => select(or(predA, not(predB)), a, b)`
---
.../Dialect/Arith/IR/ArithCanonicalization.td | 44 ++++++++++++
mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 4 +-
mlir/test/Dialect/Arith/canonicalize.mlir | 72 +++++++++++++++++++
3 files changed, 119 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
index f3d84d0b261e8dd..ae18f62864c5d67 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -233,6 +233,50 @@ def CmpIExtUI :
CPred<"$0.getValue() == arith::CmpIPredicate::eq || "
"$0.getValue() == arith::CmpIPredicate::ne">> $pred)]>;
+//===----------------------------------------------------------------------===//
+// SelectOp
+//===----------------------------------------------------------------------===//
+
+// Absorb a negated condition into the select by swapping its operands:
+//   select(not(pred), a, b) => select(pred, b, a)
+// `not(pred)` is recognized as `xori pred, <all-ones constant>`.
+// NOTE(review): `APIntAttr` appears to accept only scalar IntegerAttr constants,
+// so the splat case admitted by IsScalarOrSplatNegativeOne may never be reached
+// for vector conditions -- confirm.
+def SelectNotCond :
+    Pat<(SelectOp (Arith_XOrIOp $pred, (ConstantLikeMatcher APIntAttr:$allOnes)),
+                  $trueVal, $falseVal),
+        (SelectOp $pred, $falseVal, $trueVal),
+        [(IsScalarOrSplatNegativeOne $allOnes)]>;
+
+// When the outer select takes its true branch, the inner select (guarded by the
+// same predicate) takes its true branch too, so the inner false operand is
+// never observed:
+//   select(pred, select(pred, a, b), c) => select(pred, a, c)
+def RedundantSelectTrue :
+    Pat<(SelectOp $cond, (SelectOp $cond, $inner, $unused), $other),
+        (SelectOp $cond, $inner, $other)>;
+
+// When the outer select takes its false branch, the inner select (guarded by
+// the same predicate) takes its false branch too, so the inner true operand is
+// never observed:
+//   select(pred, a, select(pred, b, c)) => select(pred, a, c)
+def RedundantSelectFalse :
+    Pat<(SelectOp $cond, $other, (SelectOp $cond, $unused, $inner)),
+        (SelectOp $cond, $other, $inner)>;
+
+// select(predA, select(predB, a, b), b) => select(and(predA, predB), a, b)
+// The result is `a` exactly when both predicates hold, and `b` otherwise.
+// `arith.select` accepts a scalar i1 condition even when its operands are
+// vectors/tensors, so predA and predB may differ in type (e.g. i1 vs.
+// vector<4xi1>). Only fold when they match; otherwise the generated
+// `arith.andi` would be ill-typed.
+def SelectAndCond :
+    Pat<(SelectOp $predA, (SelectOp $predB, $a, $b), $b),
+        (SelectOp (Arith_AndIOp $predA, $predB), $a, $b),
+        [(Constraint<CPred<"$0.getType() == $1.getType()">> $predA, $predB)]>;
+
+// select(predA, select(predB, b, a), b) => select(and(predA, not(predB)), a, b)
+// The result is `a` exactly when predA holds and predB does not.
+// `not(predB)` is materialized as `xori predB, true` with a *scalar* i1
+// constant, and `arith.select` allows predA/predB to differ in type (scalar i1
+// vs. vector/tensor of i1). Restrict the fold to matching scalar i1 predicates
+// so the generated `arith.xori`/`arith.andi` are well-typed.
+def SelectAndNotCond :
+    Pat<(SelectOp $predA, (SelectOp $predB, $b, $a), $b),
+        (SelectOp (Arith_AndIOp $predA,
+                   (Arith_XOrIOp $predB, (Arith_ConstantOp ConstantAttr<I1Attr, "1">))),
+                  $a, $b),
+        [(Constraint<CPred<"$0.getType() == $1.getType()">> $predA, $predB),
+         (Constraint<CPred<"$0.getType().isSignlessInteger(1)">> $predB)]>;
+
+// select(predA, a, select(predB, a, b)) => select(or(predA, predB), a, b)
+// The result is `a` whenever either predicate holds, and `b` otherwise.
+// `arith.select` accepts a scalar i1 condition even when its operands are
+// vectors/tensors, so predA and predB may differ in type. Only fold when they
+// match; otherwise the generated `arith.ori` would be ill-typed.
+def SelectOrCond :
+    Pat<(SelectOp $predA, $a, (SelectOp $predB, $a, $b)),
+        (SelectOp (Arith_OrIOp $predA, $predB), $a, $b),
+        [(Constraint<CPred<"$0.getType() == $1.getType()">> $predA, $predB)]>;
+
+// select(predA, a, select(predB, b, a)) => select(or(predA, not(predB)), a, b)
+// The result is `a` whenever predA holds or predB does not.
+// `not(predB)` is materialized as `xori predB, true` with a *scalar* i1
+// constant, and `arith.select` allows predA/predB to differ in type (scalar i1
+// vs. vector/tensor of i1). Restrict the fold to matching scalar i1 predicates
+// so the generated `arith.xori`/`arith.ori` are well-typed.
+def SelectOrNotCond :
+    Pat<(SelectOp $predA, $a, (SelectOp $predB, $b, $a)),
+        (SelectOp (Arith_OrIOp $predA,
+                   (Arith_XOrIOp $predB, (Arith_ConstantOp ConstantAttr<I1Attr, "1">))),
+                  $a, $b),
+        [(Constraint<CPred<"$0.getType() == $1.getType()">> $predA, $predB),
+         (Constraint<CPred<"$0.getType().isSignlessInteger(1)">> $predB)]>;
+
//===----------------------------------------------------------------------===//
// IndexCastOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index ae8a6ef350ce191..0ecc288f3b07701 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -2212,7 +2212,9 @@ struct SelectToExtUI : public OpRewritePattern<arith::SelectOp> {
void arith::SelectOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<SelectI1Simplify, SelectToExtUI>(context);
+ // RedundantSelect* and Select*Cond are declarative (DRR) patterns defined in
+ // ArithCanonicalization.td; SelectToExtUI is a C++ pattern in this file.
+ results.add<RedundantSelectFalse, RedundantSelectTrue, SelectI1Simplify,
+ SelectAndCond, SelectAndNotCond, SelectOrCond, SelectOrNotCond,
+ SelectNotCond, SelectToExtUI>(context);
}
OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) {
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 84096354e6afe33..23d4a4ca813935d 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -128,6 +128,78 @@ func.func @selToArith(%arg0: i1, %arg1 : i1, %arg2 : i1) -> i1 {
return %res : i1
}
+// The inner select shares the outer predicate, so its false operand (%arg2)
+// can never reach the result.
+// CHECK-LABEL: @redundantSelectTrue
+// CHECK-NEXT: %[[RES:.+]] = arith.select %arg0, %arg1, %arg3
+// CHECK-NEXT: return %[[RES]]
+func.func @redundantSelectTrue(%arg0: i1, %arg1 : i32, %arg2 : i32, %arg3 : i32) -> i32 {
+  %inner = arith.select %arg0, %arg1, %arg2 : i32
+  %res = arith.select %arg0, %inner, %arg3 : i32
+  return %res : i32
+}
+
+// The inner select shares the outer predicate, so its true operand (%arg1)
+// can never reach the result.
+// CHECK-LABEL: @redundantSelectFalse
+// CHECK-NEXT: %[[RES:.+]] = arith.select %arg0, %arg3, %arg2
+// CHECK-NEXT: return %[[RES]]
+func.func @redundantSelectFalse(%arg0: i1, %arg1 : i32, %arg2 : i32, %arg3 : i32) -> i32 {
+  %inner = arith.select %arg0, %arg1, %arg2 : i32
+  %res = arith.select %arg0, %arg3, %inner : i32
+  return %res : i32
+}
+
+// Selecting on `xori %arg0, true` (i.e. not(%arg0)) is rewritten to select on
+// %arg0 directly with the operands swapped.
+// CHECK-LABEL: @selNotCond
+// CHECK-NEXT: %[[res:.+]] = arith.select %arg0, %arg2, %arg1
+// CHECK-NEXT: return %[[res]]
+func.func @selNotCond(%arg0: i1, %arg1 : i32, %arg2 : i32) -> i32 {
+  %one = arith.constant true
+  %cond = arith.xori %arg0, %one : i1
+  %res = arith.select %cond, %arg1, %arg2 : i32
+  return %res : i32
+}
+
+// %arg2 is produced only when both %arg1 and %arg0 hold; %arg3 otherwise.
+// CHECK-LABEL: @selAndCond
+// CHECK-NEXT: %[[AND:.+]] = arith.andi %arg1, %arg0
+// CHECK-NEXT: %[[RES:.+]] = arith.select %[[AND]], %arg2, %arg3
+// CHECK-NEXT: return %[[RES]]
+func.func @selAndCond(%arg0: i1, %arg1: i1, %arg2 : i32, %arg3 : i32) -> i32 {
+  %inner = arith.select %arg0, %arg2, %arg3 : i32
+  %res = arith.select %arg1, %inner, %arg3 : i32
+  return %res : i32
+}
+
+// %arg3 is produced only when %arg1 holds and %arg0 does not; %arg2 otherwise.
+// CHECK-LABEL: @selAndNotCond
+// CHECK-NEXT: %[[TRUE:.+]] = arith.constant true
+// CHECK-NEXT: %[[NOT:.+]] = arith.xori %arg0, %[[TRUE]]
+// CHECK-NEXT: %[[AND:.+]] = arith.andi %arg1, %[[NOT]]
+// CHECK-NEXT: %[[RES:.+]] = arith.select %[[AND]], %arg3, %arg2
+// CHECK-NEXT: return %[[RES]]
+func.func @selAndNotCond(%arg0: i1, %arg1: i1, %arg2 : i32, %arg3 : i32) -> i32 {
+  %inner = arith.select %arg0, %arg2, %arg3 : i32
+  %res = arith.select %arg1, %inner, %arg2 : i32
+  return %res : i32
+}
+
+// %arg2 is produced whenever %arg1 or %arg0 holds; %arg3 otherwise.
+// CHECK-LABEL: @selOrCond
+// CHECK-NEXT: %[[OR:.+]] = arith.ori %arg1, %arg0
+// CHECK-NEXT: %[[RES:.+]] = arith.select %[[OR]], %arg2, %arg3
+// CHECK-NEXT: return %[[RES]]
+func.func @selOrCond(%arg0: i1, %arg1: i1, %arg2 : i32, %arg3 : i32) -> i32 {
+  %inner = arith.select %arg0, %arg2, %arg3 : i32
+  %res = arith.select %arg1, %arg2, %inner : i32
+  return %res : i32
+}
+
+// %arg3 is produced whenever %arg1 holds or %arg0 does not; %arg2 otherwise.
+// CHECK-LABEL: @selOrNotCond
+// CHECK-NEXT: %[[TRUE:.+]] = arith.constant true
+// CHECK-NEXT: %[[NOT:.+]] = arith.xori %arg0, %[[TRUE]]
+// CHECK-NEXT: %[[OR:.+]] = arith.ori %arg1, %[[NOT]]
+// CHECK-NEXT: %[[RES:.+]] = arith.select %[[OR]], %arg3, %arg2
+// CHECK-NEXT: return %[[RES]]
+func.func @selOrNotCond(%arg0: i1, %arg1: i1, %arg2 : i32, %arg3 : i32) -> i32 {
+  %inner = arith.select %arg0, %arg2, %arg3 : i32
+  %res = arith.select %arg1, %arg3, %inner : i32
+  return %res : i32
+}
+
// Test case: Folding of comparisons with equal operands.
// CHECK-LABEL: @cmpi_equal_operands
// CHECK-DAG: %[[T:.*]] = arith.constant true
More information about the Mlir-commits
mailing list