[Mlir-commits] [mlir] 93c7918 - [MLIR] Canonicalize/fold select %x, 1, 0 to extui

William S. Moses llvmlistbot at llvm.org
Sun Jan 2 22:26:31 PST 2022


Author: William S. Moses
Date: 2022-01-03T01:26:28-05:00
New Revision: 93c791839a42cb5d81dc198452ef486fa712a860

URL: https://github.com/llvm/llvm-project/commit/93c791839a42cb5d81dc198452ef486fa712a860
DIFF: https://github.com/llvm/llvm-project/commit/93c791839a42cb5d81dc198452ef486fa712a860.diff

LOG: [MLIR] Canonicalize/fold select %x, 1, 0 to extui

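Two canonicalizations for select %x, 1, 0:
  If the return type is i1, fold the select to the condition %x; otherwise, rewrite it to an extui of %x to the return type.
  The swapped form, select %x, 0, 1, with a wider integer return type is likewise rewritten to an extui of (xori %x, true).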

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D116517

Added: 
    

Modified: 
    mlir/lib/Dialect/StandardOps/IR/Ops.cpp
    mlir/test/Dialect/Standard/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index de45339f89558..a1047a58ce2c5 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -840,9 +840,43 @@ struct SelectToNot : public OpRewritePattern<SelectOp> {
   }
 };
 
+// select %arg, %c1, %c0 => extui %arg; swapped constants add an xori first.
+struct SelectToExtUI : public OpRewritePattern<SelectOp> {
+  using OpRewritePattern<SelectOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(SelectOp op,
+                                PatternRewriter &rewriter) const override {
+    // Skip non-integer and i1 results: cannot extui i1 to i1, or i1 to f32.
+    if (!op.getType().isa<IntegerType>() || op.getType().isInteger(1))
+      return failure();
+
+    // select %arg, %c1, %c0 => extui %arg
+    if (matchPattern(op.getTrueValue(), m_One()))
+      if (matchPattern(op.getFalseValue(), m_Zero())) {
+        rewriter.replaceOpWithNewOp<arith::ExtUIOp>(op, op.getType(),
+                                                    op.getCondition());
+        return success();
+      }
+
+    // select %arg, %c0, %c1 => extui (xori %arg, true)
+    if (matchPattern(op.getTrueValue(), m_Zero()))
+      if (matchPattern(op.getFalseValue(), m_One())) {
+        rewriter.replaceOpWithNewOp<arith::ExtUIOp>(
+            op, op.getType(),
+            rewriter.create<arith::XOrIOp>(
+                op.getLoc(), op.getCondition(),
+                rewriter.create<arith::ConstantIntOp>(
+                    op.getLoc(), 1, op.getCondition().getType())));
+        return success();
+      }
+
+    return failure();
+  }
+};
+
 void SelectOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                            MLIRContext *context) {
-  results.insert<SelectToNot>(context);
+  results.insert<SelectToNot, SelectToExtUI>(context);
 }
 
 OpFoldResult SelectOp::fold(ArrayRef<Attribute> operands) {
@@ -861,6 +895,12 @@ OpFoldResult SelectOp::fold(ArrayRef<Attribute> operands) {
   if (matchPattern(condition, m_Zero()))
     return falseVal;
 
+  // select %x, true, false => %x
+  if (getType().isInteger(1))
+    if (matchPattern(getTrueValue(), m_One()))
+      if (matchPattern(getFalseValue(), m_Zero()))
+        return condition;
+
   if (auto cmp = dyn_cast_or_null<arith::CmpIOp>(condition.getDefiningOp())) {
     auto pred = cmp.getPredicate();
     if (pred == arith::CmpIPredicate::eq || pred == arith::CmpIPredicate::ne) {

diff  --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir
index 875d9f7bc4fa2..b44f78f96d070 100644
--- a/mlir/test/Dialect/Standard/canonicalize.mlir
+++ b/mlir/test/Dialect/Standard/canonicalize.mlir
@@ -29,6 +29,41 @@ func @select_cmp_ne_select(%arg0: i64, %arg1: i64) -> i64 {
 
 // -----
 
+// CHECK-LABEL: @select_extui
+//       CHECK:   %[[res:.+]] = arith.extui %arg0 : i1 to i64
+//       CHECK:   return %[[res]]
+func @select_extui(%arg0: i1) -> i64 {
+  %c0_i64 = arith.constant 0 : i64
+  %c1_i64 = arith.constant 1 : i64
+  %res = select %arg0, %c1_i64, %c0_i64 : i64
+  return %res : i64
+}
+
+// CHECK-LABEL: @select_extui2
+// CHECK-DAG:  %true = arith.constant true
+// CHECK-DAG:  %[[xor:.+]] = arith.xori %arg0, %true : i1
+// CHECK-DAG:  %[[res:.+]] = arith.extui %[[xor]] : i1 to i64
+//       CHECK:   return %[[res]]
+func @select_extui2(%arg0: i1) -> i64 {
+  %c0_i64 = arith.constant 0 : i64
+  %c1_i64 = arith.constant 1 : i64
+  %res = select %arg0, %c0_i64, %c1_i64 : i64
+  return %res : i64
+}
+
+// -----
+
+// CHECK-LABEL: @select_extui_i1
+//  CHECK-NEXT:   return %arg0
+func @select_extui_i1(%arg0: i1) -> i1 {
+  %c0_i1 = arith.constant false
+  %c1_i1 = arith.constant true
+  %res = select %arg0, %c1_i1, %c0_i1 : i1
+  return %res : i1
+}
+
+// -----
+
 // CHECK-LABEL: @branchCondProp
 //       CHECK:       %[[trueval:.+]] = arith.constant true
 //       CHECK:       %[[falseval:.+]] = arith.constant false


        


More information about the Mlir-commits mailing list