[Mlir-commits] [mlir] [mlir][arith] Fix `arith.select` canonicalization patterns (PR #84685)

Fehr Mathieu llvmlistbot at llvm.org
Wed Mar 13 08:51:43 PDT 2024


https://github.com/math-fehr updated https://github.com/llvm/llvm-project/pull/84685

>From 96ae236f104f367ac70a74ad7a406161ac51f7c5 Mon Sep 17 00:00:00 2001
From: Mathieu Fehr <mathieu.fehr at gmail.com>
Date: Sun, 10 Mar 2024 17:42:31 +0000
Subject: [PATCH] [mlir][arith] Fix `arith.select` canonicalization patterns

Because `arith.select` does not propagate poison from whichever of its
second or third operand the condition does not select, some
canonicalization patterns were incorrect. This patch removes these
incorrect patterns, and adds a new pattern to handle the case of an
`i1` select with constants.

Patterns that are removed:
  * select(predA, select(predB, x, y), y) => select(and(predA, predB), x, y)
  * select(predA, select(predB, y, x), y) => select(and(predA, not(predB)), x, y)
  * select(predA, x, select(predB, x, y)) => select(or(predA, predB), x, y)
  * select(predA, x, select(predB, y, x)) => select(or(predA, not(predB)), x, y)
  * arith.select %arg, %x, %y : i1        => and(%arg, %x) or and(!%arg, %y)

The first two patterns are incorrect when `predB` is poison and `predA`
is false, as a non-poison `y` gets compiled to `poison`. The next two
patterns are incorrect when `predB` is poison and `predA` is true, as
a non-poison `x` gets compiled to `poison`. The last pattern is incorrect
as it propagates poison from all operands after compilation.
---
 .../Dialect/Arith/IR/ArithCanonicalization.td | 34 ++------
 mlir/lib/Dialect/Arith/IR/ArithOps.cpp        | 35 +-------
 mlir/test/Dialect/Arith/canonicalize.mlir     | 80 -------------------
 3 files changed, 8 insertions(+), 141 deletions(-)

diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
index 11c4a29718e1d9..caca2ff81964f7 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -253,9 +253,6 @@ def CmpIExtUI :
 // SelectOp
 //===----------------------------------------------------------------------===//
 
-def GetScalarOrVectorTrueAttribute :
-  NativeCodeCall<"cast<TypedAttr>(getBoolAttribute($0.getType(), true))">;
-
 // select(not(pred), a, b) => select(pred, b, a)
 def SelectNotCond :
     Pat<(SelectOp (Arith_XOrIOp $pred, (ConstantLikeMatcher APIntAttr:$ones)), $a, $b),
@@ -272,31 +269,12 @@ def RedundantSelectFalse :
     Pat<(SelectOp $pred, $a, (SelectOp $pred, $b, $c)),
         (SelectOp $pred, $a, $c)>;
 
-// select(predA, select(predB, x, y), y) => select(and(predA, predB), x, y)
-def SelectAndCond :
-    Pat<(SelectOp $predA, (SelectOp $predB, $x, $y), $y),
-        (SelectOp (Arith_AndIOp $predA, $predB), $x, $y)>;
-
-// select(predA, select(predB, y, x), y) => select(and(predA, not(predB)), x, y)
-def SelectAndNotCond :
-    Pat<(SelectOp $predA, (SelectOp $predB, $y, $x), $y),
-        (SelectOp (Arith_AndIOp $predA,
-                                (Arith_XOrIOp $predB,
-                                (Arith_ConstantOp (GetScalarOrVectorTrueAttribute $predB)))),
-                  $x, $y)>;
-
-// select(predA, x, select(predB, x, y)) => select(or(predA, predB), x, y)
-def SelectOrCond :
-    Pat<(SelectOp $predA, $x, (SelectOp $predB, $x, $y)),
-        (SelectOp (Arith_OrIOp $predA, $predB), $x, $y)>;
-
-// select(predA, x, select(predB, y, x)) => select(or(predA, not(predB)), x, y)
-def SelectOrNotCond :
-    Pat<(SelectOp $predA, $x, (SelectOp $predB, $y, $x)),
-        (SelectOp (Arith_OrIOp $predA,
-                               (Arith_XOrIOp $predB,
-                               (Arith_ConstantOp (GetScalarOrVectorTrueAttribute $predB)))),
-                  $x, $y)>;
+// select(pred, false, true) => not(pred) 
+def SelectI1ToNot :
+    Pat<(SelectOp $pred,
+                  (ConstantLikeMatcher ConstantAttr<I1Attr, "0">),
+                  (ConstantLikeMatcher ConstantAttr<I1Attr, "1">)),
+        (Arith_XOrIOp $pred, (Arith_ConstantOp ConstantAttr<I1Attr, "1">))>;
 
 //===----------------------------------------------------------------------===//
 // IndexCastOp
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 0f71c19c23b654..9f64a07f31e3af 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -969,7 +969,6 @@ OpFoldResult arith::MaxNumFOp::fold(FoldAdaptor adaptor) {
       [](const APFloat &a, const APFloat &b) { return llvm::maximum(a, b); });
 }
 
-
 //===----------------------------------------------------------------------===//
 // MaxSIOp
 //===----------------------------------------------------------------------===//
@@ -2173,35 +2172,6 @@ void arith::CmpFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
 // SelectOp
 //===----------------------------------------------------------------------===//
 
-// Transforms a select of a boolean to arithmetic operations
-//
-//  arith.select %arg, %x, %y : i1
-//
-//  becomes
-//
-//  and(%arg, %x) or and(!%arg, %y)
-struct SelectI1Simplify : public OpRewritePattern<arith::SelectOp> {
-  using OpRewritePattern<arith::SelectOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(arith::SelectOp op,
-                                PatternRewriter &rewriter) const override {
-    if (!op.getType().isInteger(1))
-      return failure();
-
-    Value falseConstant =
-        rewriter.create<arith::ConstantIntOp>(op.getLoc(), true, 1);
-    Value notCondition = rewriter.create<arith::XOrIOp>(
-        op.getLoc(), op.getCondition(), falseConstant);
-
-    Value trueVal = rewriter.create<arith::AndIOp>(
-        op.getLoc(), op.getCondition(), op.getTrueValue());
-    Value falseVal = rewriter.create<arith::AndIOp>(op.getLoc(), notCondition,
-                                                    op.getFalseValue());
-    rewriter.replaceOpWithNewOp<arith::OrIOp>(op, trueVal, falseVal);
-    return success();
-  }
-};
-
 //  select %arg, %c1, %c0 => extui %arg
 struct SelectToExtUI : public OpRewritePattern<arith::SelectOp> {
   using OpRewritePattern<arith::SelectOp>::OpRewritePattern;
@@ -2238,9 +2208,8 @@ struct SelectToExtUI : public OpRewritePattern<arith::SelectOp> {
 
 void arith::SelectOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                   MLIRContext *context) {
-  results.add<RedundantSelectFalse, RedundantSelectTrue, SelectI1Simplify,
-              SelectAndCond, SelectAndNotCond, SelectOrCond, SelectOrNotCond,
-              SelectNotCond, SelectToExtUI>(context);
+  results.add<RedundantSelectFalse, RedundantSelectTrue, SelectNotCond,
+              SelectI1ToNot, SelectToExtUI>(context);
 }
 
 OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) {
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index cb98a10048a309..bdc6c91d926775 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -116,18 +116,6 @@ func.func @selToNot(%arg0: i1) -> i1 {
   return %res : i1
 }
 
-// CHECK-LABEL: @selToArith
-//       CHECK-NEXT:       %[[trueval:.+]] = arith.constant true
-//       CHECK-NEXT:       %[[notcmp:.+]] = arith.xori %arg0, %[[trueval]] : i1
-//       CHECK-NEXT:       %[[condtrue:.+]] = arith.andi %arg0, %arg1 : i1
-//       CHECK-NEXT:       %[[condfalse:.+]] = arith.andi %[[notcmp]], %arg2 : i1
-//       CHECK-NEXT:       %[[res:.+]] = arith.ori %[[condtrue]], %[[condfalse]] : i1
-//       CHECK:   return %[[res]]
-func.func @selToArith(%arg0: i1, %arg1 : i1, %arg2 : i1) -> i1 {
-  %res = arith.select %arg0, %arg1, %arg2 : i1
-  return %res : i1
-}
-
 // CHECK-LABEL: @redundantSelectTrue
 //       CHECK-NEXT: %[[res:.+]] = arith.select %arg0, %arg1, %arg3
 //       CHECK-NEXT: return %[[res]]
@@ -160,74 +148,6 @@ func.func @selNotCond(%arg0: i1, %arg1 : i32, %arg2 : i32, %arg3 : i32, %arg4 :
   return %res1, %res2 : i32, i32
 }
 
-// CHECK-LABEL: @selAndCond
-//       CHECK-NEXT: %[[and:.+]] = arith.andi %arg1, %arg0
-//       CHECK-NEXT: %[[res:.+]] = arith.select %[[and]], %arg2, %arg3
-//       CHECK-NEXT: return %[[res]]
-func.func @selAndCond(%arg0: i1, %arg1: i1, %arg2 : i32, %arg3 : i32) -> i32 {
-  %sel = arith.select %arg0, %arg2, %arg3 : i32
-  %res = arith.select %arg1, %sel, %arg3 : i32
-  return %res : i32
-}
-
-// CHECK-LABEL: @selAndNotCond
-//       CHECK-NEXT: %[[one:.+]] = arith.constant true
-//       CHECK-NEXT: %[[not:.+]] = arith.xori %arg0, %[[one]]
-//       CHECK-NEXT: %[[and:.+]] = arith.andi %arg1, %[[not]]
-//       CHECK-NEXT: %[[res:.+]] = arith.select %[[and]], %arg3, %arg2
-//       CHECK-NEXT: return %[[res]]
-func.func @selAndNotCond(%arg0: i1, %arg1: i1, %arg2 : i32, %arg3 : i32) -> i32 {
-  %sel = arith.select %arg0, %arg2, %arg3 : i32
-  %res = arith.select %arg1, %sel, %arg2 : i32
-  return %res : i32
-}
-
-// CHECK-LABEL: @selAndNotCondVec
-//       CHECK-NEXT: %[[one:.+]] = arith.constant dense<true> : vector<4xi1>
-//       CHECK-NEXT: %[[not:.+]] = arith.xori %arg0, %[[one]]
-//       CHECK-NEXT: %[[and:.+]] = arith.andi %arg1, %[[not]]
-//       CHECK-NEXT: %[[res:.+]] = arith.select %[[and]], %arg3, %arg2
-//       CHECK-NEXT: return %[[res]]
-func.func @selAndNotCondVec(%arg0: vector<4xi1>, %arg1: vector<4xi1>, %arg2 : vector<4xi32>, %arg3 : vector<4xi32>) -> vector<4xi32> {
-  %sel = arith.select %arg0, %arg2, %arg3 : vector<4xi1>, vector<4xi32>
-  %res = arith.select %arg1, %sel, %arg2 : vector<4xi1>, vector<4xi32>
-  return %res : vector<4xi32>
-}
-
-// CHECK-LABEL: @selOrCond
-//       CHECK-NEXT: %[[or:.+]] = arith.ori %arg1, %arg0
-//       CHECK-NEXT: %[[res:.+]] = arith.select %[[or]], %arg2, %arg3
-//       CHECK-NEXT: return %[[res]]
-func.func @selOrCond(%arg0: i1, %arg1: i1, %arg2 : i32, %arg3 : i32) -> i32 {
-  %sel = arith.select %arg0, %arg2, %arg3 : i32
-  %res = arith.select %arg1, %arg2, %sel : i32
-  return %res : i32
-}
-
-// CHECK-LABEL: @selOrNotCond
-//       CHECK-NEXT: %[[one:.+]] = arith.constant true
-//       CHECK-NEXT: %[[not:.+]] = arith.xori %arg0, %[[one]]
-//       CHECK-NEXT: %[[or:.+]] = arith.ori %arg1, %[[not]]
-//       CHECK-NEXT: %[[res:.+]] = arith.select %[[or]], %arg3, %arg2
-//       CHECK-NEXT: return %[[res]]
-func.func @selOrNotCond(%arg0: i1, %arg1: i1, %arg2 : i32, %arg3 : i32) -> i32 {
-  %sel = arith.select %arg0, %arg2, %arg3 : i32
-  %res = arith.select %arg1, %arg3, %sel : i32
-  return %res : i32
-}
-
-// CHECK-LABEL: @selOrNotCondVec
-//       CHECK-NEXT: %[[one:.+]] = arith.constant dense<true> : vector<4xi1>
-//       CHECK-NEXT: %[[not:.+]] = arith.xori %arg0, %[[one]]
-//       CHECK-NEXT: %[[or:.+]] = arith.ori %arg1, %[[not]]
-//       CHECK-NEXT: %[[res:.+]] = arith.select %[[or]], %arg3, %arg2
-//       CHECK-NEXT: return %[[res]]
-func.func @selOrNotCondVec(%arg0: vector<4xi1>, %arg1: vector<4xi1>, %arg2 : vector<4xi32>, %arg3 : vector<4xi32>) -> vector<4xi32> {
-  %sel = arith.select %arg0, %arg2, %arg3 : vector<4xi1>, vector<4xi32>
-  %res = arith.select %arg1, %arg3, %sel : vector<4xi1>, vector<4xi32>
-  return %res : vector<4xi32>
-}
-
 // Test case: Folding of comparisons with equal operands.
 // CHECK-LABEL: @cmpi_equal_operands
 //   CHECK-DAG:   %[[T:.*]] = arith.constant true



More information about the Mlir-commits mailing list