[Mlir-commits] [mlir] [mlir][arith] Canonicalization patterns for `arith.select` (PR #67809)

Han-Chung Wang llvmlistbot at llvm.org
Thu Oct 12 16:33:23 PDT 2023


================
@@ -233,6 +233,50 @@ def CmpIExtUI :
             CPred<"$0.getValue() == arith::CmpIPredicate::eq || "
                   "$0.getValue() == arith::CmpIPredicate::ne">> $pred)]>;
 
+//===----------------------------------------------------------------------===//
+// SelectOp
+//===----------------------------------------------------------------------===//
+
+// select(not(pred), a, b) => select(pred, b, a)
+def SelectNotCond :
+    Pat<(SelectOp (Arith_XOrIOp $pred, (ConstantLikeMatcher APIntAttr:$ones)), $a, $b),
+        (SelectOp $pred, $b, $a),
+        [(IsScalarOrSplatNegativeOne $ones)]>;
+
+// select(pred, select(pred, a, b), c) => select(pred, a, c)
+def RedundantSelectTrue :
+    Pat<(SelectOp $pred, (SelectOp $pred, $a, $b), $c),
+        (SelectOp $pred, $a, $c)>;
+
+// select(pred, a, select(pred, b, c)) => select(pred, a, c)
+def RedundantSelectFalse :
+    Pat<(SelectOp $pred, $a, (SelectOp $pred, $b, $c)),
+        (SelectOp $pred, $a, $c)>;
+
+// select(predA, select(predB, x, y), y) => select(and(predA, predB), x, y)
+def SelectAndCond :
+    Pat<(SelectOp $predA, (SelectOp $predB, $x, $y), $y),
+        (SelectOp (Arith_AndIOp $predA, $predB), $x, $y)>;
+
+// select(predA, select(predB, y, x), y) => select(and(predA, not(predB)), x, y)
+def SelectAndNotCond :
+    Pat<(SelectOp $predA, (SelectOp $predB, $y, $x), $y),
+        (SelectOp (Arith_AndIOp $predA,
+                                (Arith_XOrIOp $predB, (Arith_ConstantOp ConstantAttr<I1Attr, "1">))),
----------------
hanhanW wrote:

Here is an example that fails:

```
  func.func @foo(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> vector<4xf32> {
    %0 = arith.cmpf ult, %arg0, %arg1 : vector<4xf32>
    %1 = arith.select %0, %arg0, %arg1 : vector<4xi1>, vector<4xf32>
    %2 = arith.cmpf uno, %arg1, %arg1 : vector<4xf32>
    %3 = arith.select %2, %arg1, %1 : vector<4xi1>, vector<4xf32>
    return %3 : vector<4xf32>
  }
```
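
As a cross-check, the identities stated in the comments of the hunk above can be brute-forced over scalar booleans. This is only a rough sketch in Python, not part of the patch: `select`, `not_`, `IDENTITIES`, and `check_all` are hypothetical names introduced for illustration. It models scalar `i1` values only, so it says nothing about types and does not cover the `vector<4xi1>` predicates used in the example.

```python
from itertools import product


def select(pred, a, b):
    """Scalar model of arith.select: yields a when pred is true, else b."""
    return a if pred else b


def not_(pred):
    """Scalar model of xori(pred, true) on an i1 value."""
    return pred ^ True


# Three distinct placeholder values, so that any wrong pick is observable.
A, B, C = "a", "b", "c"

# Hypothetical table: pattern name -> (original form, rewritten form).
# p and q stand in for the predicates ($pred, or $predA and $predB).
IDENTITIES = {
    "SelectNotCond": (
        lambda p, q: select(not_(p), A, B),
        lambda p, q: select(p, B, A),
    ),
    "RedundantSelectTrue": (
        lambda p, q: select(p, select(p, A, B), C),
        lambda p, q: select(p, A, C),
    ),
    "RedundantSelectFalse": (
        lambda p, q: select(p, A, select(p, B, C)),
        lambda p, q: select(p, A, C),
    ),
    "SelectAndCond": (
        lambda p, q: select(p, select(q, A, B), B),
        lambda p, q: select(p and q, A, B),
    ),
    "SelectAndNotCond": (
        lambda p, q: select(p, select(q, B, A), B),
        lambda p, q: select(p and not_(q), A, B),
    ),
}


def check_all():
    """Return {pattern name: True if both forms agree on every predicate pair}."""
    return {
        name: all(
            lhs(p, q) == rhs(p, q)
            for p, q in product([False, True], repeat=2)
        )
        for name, (lhs, rhs) in IDENTITIES.items()
    }


if __name__ == "__main__":
    for name, ok in check_all().items():
        print(f"{name}: {'holds' if ok else 'FAILS'}")
```

Because the check is purely about boolean semantics, a pattern that passes here can still fail on real IR for reasons this model ignores, such as operand types.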

https://github.com/llvm/llvm-project/pull/67809

