[Mlir-commits] [mlir] a02af37 - [MLIR] Generalize select to arithmetic canonicalization

William S. Moses llvmlistbot at llvm.org
Mon Jan 10 08:50:21 PST 2022


Author: William S. Moses
Date: 2022-01-10T11:50:17-05:00
New Revision: a02af37560ff5aa22dcef5735ef25eaf58eaaf64

URL: https://github.com/llvm/llvm-project/commit/a02af37560ff5aa22dcef5735ef25eaf58eaaf64
DIFF: https://github.com/llvm/llvm-project/commit/a02af37560ff5aa22dcef5735ef25eaf58eaaf64.diff

LOG: [MLIR] Generalize select to arithmetic canonicalization

Given a select whose result is an i1, we can eliminate the conditional in the select completely by rewriting it as a few bitwise arithmetic operations: `select %c, %x, %y : i1` becomes `or(and(%c, %x), and(xor(%c, true), %y))`.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D116839

Added: 
    

Modified: 
    mlir/lib/Dialect/StandardOps/IR/Ops.cpp
    mlir/test/Dialect/Standard/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index a1047a58ce2c5..ed6698ba281c7 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -813,29 +813,31 @@ static LogicalResult verify(ReturnOp op) {
 // SelectOp
 //===----------------------------------------------------------------------===//
 
-// Transforms a select to a not, where relevant.
+// Transforms a select of i1 values into bitwise arithmetic operations.
 //
-//  select %arg, %false, %true
+//  select %arg, %x, %y : i1
 //
 //  becomes
 //
-//  xor %arg, %true
-struct SelectToNot : public OpRewritePattern<SelectOp> {
+//  or(and(%arg, %x), and(xor(%arg, true), %y))
+struct SelectI1Simplify : public OpRewritePattern<SelectOp> {
   using OpRewritePattern<SelectOp>::OpRewritePattern;
 
   LogicalResult matchAndRewrite(SelectOp op,
                                 PatternRewriter &rewriter) const override {
-    if (!matchPattern(op.getTrueValue(), m_Zero()))
-      return failure();
-
-    if (!matchPattern(op.getFalseValue(), m_One()))
-      return failure();
-
     if (!op.getType().isInteger(1))
       return failure();
 
-    rewriter.replaceOpWithNewOp<arith::XOrIOp>(op, op.getCondition(),
-                                               op.getFalseValue());
+    // not(cond) is materialized as cond ^ true (the i1 all-ones value).
+    Value trueConstant =
+        rewriter.create<arith::ConstantIntOp>(op.getLoc(), true, 1);
+    Value notCondition = rewriter.create<arith::XOrIOp>(
+        op.getLoc(), op.getCondition(), trueConstant);
+    Value trueVal = rewriter.create<arith::AndIOp>(
+        op.getLoc(), op.getCondition(), op.getTrueValue());
+    Value falseVal = rewriter.create<arith::AndIOp>(op.getLoc(), notCondition,
+                                                    op.getFalseValue());
+    rewriter.replaceOpWithNewOp<arith::OrIOp>(op, trueVal, falseVal);
     return success();
   }
 };
@@ -876,7 +878,7 @@ struct SelectToExtUI : public OpRewritePattern<SelectOp> {
 
 void SelectOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                            MLIRContext *context) {
-  results.insert<SelectToNot, SelectToExtUI>(context);
+  results.add<SelectI1Simplify, SelectToExtUI>(context);
 }
 
 OpFoldResult SelectOp::fold(ArrayRef<Attribute> operands) {

diff  --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir
index b44f78f96d070..67d95ce0d194c 100644
--- a/mlir/test/Dialect/Standard/canonicalize.mlir
+++ b/mlir/test/Dialect/Standard/canonicalize.mlir
@@ -88,10 +88,23 @@ func @branchCondProp(%arg0: i1) {
 
 // CHECK-LABEL: @selToNot
 //       CHECK:       %[[trueval:.+]] = arith.constant true
-//       CHECK:       %{{.+}} = arith.xori %arg0, %[[trueval]] : i1
+//       CHECK:       %[[res:.+]] = arith.xori %arg0, %[[trueval]] : i1
+//       CHECK-NEXT:   return %[[res]]
 func @selToNot(%arg0: i1) -> i1 {
   %true = arith.constant true
   %false = arith.constant false
   %res = select %arg0, %false, %true : i1
   return %res : i1
 }
+
+// CHECK-LABEL: @selToArith
+//       CHECK-NEXT:       %[[trueval:.+]] = arith.constant true
+//       CHECK-NEXT:       %[[notcmp:.+]] = arith.xori %arg0, %[[trueval]] : i1
+//       CHECK-NEXT:       %[[condtrue:.+]] = arith.andi %arg0, %arg1 : i1
+//       CHECK-NEXT:       %[[condfalse:.+]] = arith.andi %[[notcmp]], %arg2 : i1
+//       CHECK-NEXT:       %[[res:.+]] = arith.ori %[[condtrue]], %[[condfalse]] : i1
+//       CHECK-NEXT:   return %[[res]]
+func @selToArith(%arg0: i1, %arg1 : i1, %arg2 : i1) -> i1 {
+  %res = select %arg0, %arg1, %arg2 : i1
+  return %res : i1
+}


        


More information about the Mlir-commits mailing list