[flang-commits] [flang] Reland "[mlir][arith] Canonicalization patterns for `arith.select` (#67809)" (PR #68941)

Han-Chung Wang via flang-commits flang-commits at lists.llvm.org
Fri Oct 13 10:46:04 PDT 2023


https://github.com/hanhanW updated https://github.com/llvm/llvm-project/pull/68941

>From 877111a139b2f01037fdbe7c0cb120a2f4e64562 Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Thu, 12 Oct 2023 17:14:29 -0700
Subject: [PATCH 1/2] Reland "[mlir][arith] Canonicalization patterns for
 `arith.select` (#67809)"

This cherry-picks the changes in
llvm-project/5bf701a6687a46fd898621f5077959ff202d716b, with the types
limited to i1.
---
 .../Dialect/Arith/IR/ArithCanonicalization.td | 46 +++++++++++
 mlir/lib/Dialect/Arith/IR/ArithOps.cpp        |  4 +-
 mlir/test/Dialect/Arith/canonicalize.mlir     | 76 +++++++++++++++++++
 3 files changed, 125 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
index f3d84d0b261e8dd..817de0e06c661b9 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -233,6 +233,52 @@ def CmpIExtUI :
             CPred<"$0.getValue() == arith::CmpIPredicate::eq || "
                   "$0.getValue() == arith::CmpIPredicate::ne">> $pred)]>;
 
+//===----------------------------------------------------------------------===//
+// SelectOp
+//===----------------------------------------------------------------------===//
+
+// select(not(pred), a, b) => select(pred, b, a)
+def SelectNotCond :
+    Pat<(SelectOp (Arith_XOrIOp $pred, (ConstantLikeMatcher APIntAttr:$ones)), $a, $b),
+        (SelectOp $pred, $b, $a),
+        [(IsScalarOrSplatNegativeOne $ones)]>;
+
+// select(pred, select(pred, a, b), c) => select(pred, a, c)
+def RedundantSelectTrue :
+    Pat<(SelectOp $pred, (SelectOp $pred, $a, $b), $c),
+        (SelectOp $pred, $a, $c)>;
+
+// select(pred, a, select(pred, b, c)) => select(pred, a, c)
+def RedundantSelectFalse :
+    Pat<(SelectOp $pred, $a, (SelectOp $pred, $b, $c)),
+        (SelectOp $pred, $a, $c)>;
+
+// select(predA, select(predB, x, y), y) => select(and(predA, predB), x, y)
+def SelectAndCond :
+    Pat<(SelectOp $predA, (SelectOp $predB, $x, $y), $y), (SelectOp (Arith_AndIOp $predA, $predB), $x, $y),
+        [(Constraint<CPred<"$0.getType() == $1.getType()">> $predA, $predB)]>;
+
+// select(predA, select(predB, y, x), y) => select(and(predA, not(predB)), x, y)
+def SelectAndNotCond :
+    Pat<(SelectOp $predA, (SelectOp $predB, $y, $x), $y),
+        (SelectOp (Arith_AndIOp $predA,
+                                (Arith_XOrIOp $predB, (Arith_ConstantOp ConstantAttr<I1Attr, "1">))),
+                  $x, $y),
+        [(Constraint<CPred<"$0.getType() == $_builder.getI1Type()">> $predB)]>;
+
+// select(predA, x, select(predB, x, y)) => select(or(predA, predB), x, y)
+def SelectOrCond :
+    Pat<(SelectOp $predA, $x, (SelectOp $predB, $x, $y)), (SelectOp (Arith_OrIOp $predA, $predB), $x, $y),
+        [(Constraint<CPred<"$0.getType() == $1.getType()">> $predA, $predB)]>;
+
+// select(predA, x, select(predB, y, x)) => select(or(predA, not(predB)), x, y)
+def SelectOrNotCond :
+    Pat<(SelectOp $predA, $x, (SelectOp $predB, $y, $x)),
+        (SelectOp (Arith_OrIOp $predA,
+                               (Arith_XOrIOp $predB, (Arith_ConstantOp ConstantAttr<I1Attr, "1">))),
+                  $x, $y),
+        [(Constraint<CPred<"$0.getType() == $_builder.getI1Type()">> $predB)]>;
+
 //===----------------------------------------------------------------------===//
 // IndexCastOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index ae8a6ef350ce191..0ecc288f3b07701 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -2212,7 +2212,9 @@ struct SelectToExtUI : public OpRewritePattern<arith::SelectOp> {
 
 void arith::SelectOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                   MLIRContext *context) {
-  results.add<SelectI1Simplify, SelectToExtUI>(context);
+  results.add<RedundantSelectFalse, RedundantSelectTrue, SelectI1Simplify,
+              SelectAndCond, SelectAndNotCond, SelectOrCond, SelectOrNotCond,
+              SelectNotCond, SelectToExtUI>(context);
 }
 
 OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) {
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index f697f3d01458eee..1b0547c9e8f804a 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -128,6 +128,82 @@ func.func @selToArith(%arg0: i1, %arg1 : i1, %arg2 : i1) -> i1 {
   return %res : i1
 }
 
+// CHECK-LABEL: @redundantSelectTrue
+//       CHECK-NEXT: %[[res:.+]] = arith.select %arg0, %arg1, %arg3
+//       CHECK-NEXT: return %[[res]]
+func.func @redundantSelectTrue(%arg0: i1, %arg1 : i32, %arg2 : i32, %arg3 : i32) -> i32 {
+  %0 = arith.select %arg0, %arg1, %arg2 : i32
+  %res = arith.select %arg0, %0, %arg3 : i32
+  return %res : i32
+}
+
+// CHECK-LABEL: @redundantSelectFalse
+//       CHECK-NEXT: %[[res:.+]] = arith.select %arg0, %arg3, %arg2
+//       CHECK-NEXT: return %[[res]]
+func.func @redundantSelectFalse(%arg0: i1, %arg1 : i32, %arg2 : i32, %arg3 : i32) -> i32 {
+  %0 = arith.select %arg0, %arg1, %arg2 : i32
+  %res = arith.select %arg0, %arg3, %0 : i32
+  return %res : i32
+}
+
+// CHECK-LABEL: @selNotCond
+//       CHECK-NEXT: %[[res1:.+]] = arith.select %arg0, %arg2, %arg1
+//       CHECK-NEXT: %[[res2:.+]] = arith.select %arg0, %arg4, %arg3
+//       CHECK-NEXT: return %[[res1]], %[[res2]]
+func.func @selNotCond(%arg0: i1, %arg1 : i32, %arg2 : i32, %arg3 : i32, %arg4 : i32) -> (i32, i32) {
+  %one = arith.constant 1 : i1
+  %cond1 = arith.xori %arg0, %one : i1
+  %cond2 = arith.xori %one, %arg0 : i1
+
+  %res1 = arith.select %cond1, %arg1, %arg2 : i32
+  %res2 = arith.select %cond2, %arg3, %arg4 : i32
+  return %res1, %res2 : i32, i32
+}
+
+// CHECK-LABEL: @selAndCond
+//       CHECK-NEXT: %[[and:.+]] = arith.andi %arg1, %arg0
+//       CHECK-NEXT: %[[res:.+]] = arith.select %[[and]], %arg2, %arg3
+//       CHECK-NEXT: return %[[res]]
+func.func @selAndCond(%arg0: i1, %arg1: i1, %arg2 : i32, %arg3 : i32) -> i32 {
+  %sel = arith.select %arg0, %arg2, %arg3 : i32
+  %res = arith.select %arg1, %sel, %arg3 : i32
+  return %res : i32
+}
+
+// CHECK-LABEL: @selAndNotCond
+//       CHECK-NEXT: %[[one:.+]] = arith.constant true
+//       CHECK-NEXT: %[[not:.+]] = arith.xori %arg0, %[[one]]
+//       CHECK-NEXT: %[[and:.+]] = arith.andi %arg1, %[[not]]
+//       CHECK-NEXT: %[[res:.+]] = arith.select %[[and]], %arg3, %arg2
+//       CHECK-NEXT: return %[[res]]
+func.func @selAndNotCond(%arg0: i1, %arg1: i1, %arg2 : i32, %arg3 : i32) -> i32 {
+  %sel = arith.select %arg0, %arg2, %arg3 : i32
+  %res = arith.select %arg1, %sel, %arg2 : i32
+  return %res : i32
+}
+
+// CHECK-LABEL: @selOrCond
+//       CHECK-NEXT: %[[or:.+]] = arith.ori %arg1, %arg0
+//       CHECK-NEXT: %[[res:.+]] = arith.select %[[or]], %arg2, %arg3
+//       CHECK-NEXT: return %[[res]]
+func.func @selOrCond(%arg0: i1, %arg1: i1, %arg2 : i32, %arg3 : i32) -> i32 {
+  %sel = arith.select %arg0, %arg2, %arg3 : i32
+  %res = arith.select %arg1, %arg2, %sel : i32
+  return %res : i32
+}
+
+// CHECK-LABEL: @selOrNotCond
+//       CHECK-NEXT: %[[one:.+]] = arith.constant true
+//       CHECK-NEXT: %[[not:.+]] = arith.xori %arg0, %[[one]]
+//       CHECK-NEXT: %[[or:.+]] = arith.ori %arg1, %[[not]]
+//       CHECK-NEXT: %[[res:.+]] = arith.select %[[or]], %arg3, %arg2
+//       CHECK-NEXT: return %[[res]]
+func.func @selOrNotCond(%arg0: i1, %arg1: i1, %arg2 : i32, %arg3 : i32) -> i32 {
+  %sel = arith.select %arg0, %arg2, %arg3 : i32
+  %res = arith.select %arg1, %arg3, %sel : i32
+  return %res : i32
+}
+
 // Test case: Folding of comparisons with equal operands.
 // CHECK-LABEL: @cmpi_equal_operands
 //   CHECK-DAG:   %[[T:.*]] = arith.constant true

>From 506e0c83d65845c62737bc915878ae47008bbc28 Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Fri, 13 Oct 2023 10:45:11 -0700
Subject: [PATCH 2/2] extend patterns to handle vector types

---
 .../Dialect/Arith/IR/ArithCanonicalization.td | 15 +++++++-----
 mlir/lib/Dialect/Arith/IR/ArithOps.cpp        | 18 +++++++-------
 mlir/test/Dialect/Arith/canonicalize.mlir     | 24 +++++++++++++++++++
 3 files changed, 42 insertions(+), 15 deletions(-)

diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
index 817de0e06c661b9..ef951647ccd1464 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -237,6 +237,9 @@ def CmpIExtUI :
 // SelectOp
 //===----------------------------------------------------------------------===//
 
+def GetScalarOrVectorTrueAttribute :
+  NativeCodeCall<"cast<TypedAttr>(getBoolAttribute($0.getType(), true))">;
+
 // select(not(pred), a, b) => select(pred, b, a)
 def SelectNotCond :
     Pat<(SelectOp (Arith_XOrIOp $pred, (ConstantLikeMatcher APIntAttr:$ones)), $a, $b),
@@ -262,9 +265,9 @@ def SelectAndCond :
 def SelectAndNotCond :
     Pat<(SelectOp $predA, (SelectOp $predB, $y, $x), $y),
         (SelectOp (Arith_AndIOp $predA,
-                                (Arith_XOrIOp $predB, (Arith_ConstantOp ConstantAttr<I1Attr, "1">))),
-                  $x, $y),
-        [(Constraint<CPred<"$0.getType() == $_builder.getI1Type()">> $predB)]>;
+                                (Arith_XOrIOp $predB, (Arith_ConstantOp (GetScalarOrVectorTrueAttribute $predB)))),
+                  $x, $y),
+        [(Constraint<CPred<"$0.getType() == $1.getType() && (!isa<ShapedType>($1.getType()) || cast<ShapedType>($1.getType()).hasStaticShape())">> $predA, $predB)]>;
 
 // select(predA, x, select(predB, x, y)) => select(or(predA, predB), x, y)
 def SelectOrCond :
@@ -275,9 +278,9 @@ def SelectOrCond :
 def SelectOrNotCond :
     Pat<(SelectOp $predA, $x, (SelectOp $predB, $y, $x)),
         (SelectOp (Arith_OrIOp $predA,
-                               (Arith_XOrIOp $predB, (Arith_ConstantOp ConstantAttr<I1Attr, "1">))),
-                  $x, $y),
-        [(Constraint<CPred<"$0.getType() == $_builder.getI1Type()">> $predB)]>;
+                               (Arith_XOrIOp $predB, (Arith_ConstantOp (GetScalarOrVectorTrueAttribute $predB)))),
+                  $x, $y),
+        [(Constraint<CPred<"$0.getType() == $1.getType() && (!isa<ShapedType>($1.getType()) || cast<ShapedType>($1.getType()).hasStaticShape())">> $predA, $predB)]>;
 
 //===----------------------------------------------------------------------===//
 // IndexCastOp
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 0ecc288f3b07701..02bab31971dcbe4 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -113,6 +113,14 @@ static FailureOr<APInt> getIntOrSplatIntValue(Attribute attr) {
   return failure();
 }
 
+static Attribute getBoolAttribute(Type type, bool value) {
+  auto boolAttr = BoolAttr::get(type.getContext(), value);
+  ShapedType shapedType = llvm::dyn_cast_or_null<ShapedType>(type);
+  if (!shapedType)
+    return boolAttr;
+  return DenseElementsAttr::get(shapedType, boolAttr);
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd canonicalization patterns
 //===----------------------------------------------------------------------===//
@@ -1696,14 +1704,6 @@ static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate) {
   llvm_unreachable("unknown cmpi predicate kind");
 }
 
-static Attribute getBoolAttribute(Type type, MLIRContext *ctx, bool value) {
-  auto boolAttr = BoolAttr::get(ctx, value);
-  ShapedType shapedType = llvm::dyn_cast_or_null<ShapedType>(type);
-  if (!shapedType)
-    return boolAttr;
-  return DenseElementsAttr::get(shapedType, boolAttr);
-}
-
 static std::optional<int64_t> getIntegerWidth(Type t) {
   if (auto intType = llvm::dyn_cast<IntegerType>(t)) {
     return intType.getWidth();
@@ -1718,7 +1718,7 @@ OpFoldResult arith::CmpIOp::fold(FoldAdaptor adaptor) {
   // cmpi(pred, x, x)
   if (getLhs() == getRhs()) {
     auto val = applyCmpPredicateToEqualOperands(getPredicate());
-    return getBoolAttribute(getType(), getContext(), val);
+    return getBoolAttribute(getType(), val);
   }
 
   if (matchPattern(adaptor.getRhs(), m_Zero())) {
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 1b0547c9e8f804a..abe9737b25443e8 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -182,6 +182,18 @@ func.func @selAndNotCond(%arg0: i1, %arg1: i1, %arg2 : i32, %arg3 : i32) -> i32
   return %res : i32
 }
 
+// CHECK-LABEL: @selAndNotCondVec
+//       CHECK-NEXT: %[[one:.+]] = arith.constant dense<true> : vector<4xi1>
+//       CHECK-NEXT: %[[not:.+]] = arith.xori %arg0, %[[one]]
+//       CHECK-NEXT: %[[and:.+]] = arith.andi %arg1, %[[not]]
+//       CHECK-NEXT: %[[res:.+]] = arith.select %[[and]], %arg3, %arg2
+//       CHECK-NEXT: return %[[res]]
+func.func @selAndNotCondVec(%arg0: vector<4xi1>, %arg1: vector<4xi1>, %arg2 : vector<4xi32>, %arg3 : vector<4xi32>) -> vector<4xi32> {
+  %sel = arith.select %arg0, %arg2, %arg3 : vector<4xi1>, vector<4xi32>
+  %res = arith.select %arg1, %sel, %arg2 : vector<4xi1>, vector<4xi32>
+  return %res : vector<4xi32>
+}
+
 // CHECK-LABEL: @selOrCond
 //       CHECK-NEXT: %[[or:.+]] = arith.ori %arg1, %arg0
 //       CHECK-NEXT: %[[res:.+]] = arith.select %[[or]], %arg2, %arg3
@@ -204,6 +216,18 @@ func.func @selOrNotCond(%arg0: i1, %arg1: i1, %arg2 : i32, %arg3 : i32) -> i32 {
   return %res : i32
 }
 
+// CHECK-LABEL: @selOrNotCondVec
+//       CHECK-NEXT: %[[one:.+]] = arith.constant dense<true> : vector<4xi1>
+//       CHECK-NEXT: %[[not:.+]] = arith.xori %arg0, %[[one]]
+//       CHECK-NEXT: %[[or:.+]] = arith.ori %arg1, %[[not]]
+//       CHECK-NEXT: %[[res:.+]] = arith.select %[[or]], %arg3, %arg2
+//       CHECK-NEXT: return %[[res]]
+func.func @selOrNotCondVec(%arg0: vector<4xi1>, %arg1: vector<4xi1>, %arg2 : vector<4xi32>, %arg3 : vector<4xi32>) -> vector<4xi32> {
+  %sel = arith.select %arg0, %arg2, %arg3 : vector<4xi1>, vector<4xi32>
+  %res = arith.select %arg1, %arg3, %sel : vector<4xi1>, vector<4xi32>
+  return %res : vector<4xi32>
+}
+
 // Test case: Folding of comparisons with equal operands.
 // CHECK-LABEL: @cmpi_equal_operands
 //   CHECK-DAG:   %[[T:.*]] = arith.constant true



More information about the flang-commits mailing list