[Mlir-commits] [mlir] a02af37 - [MLIR] Generalize select to arithmetic canonicalization
William S. Moses
llvmlistbot at llvm.org
Mon Jan 10 08:50:21 PST 2022
Author: William S. Moses
Date: 2022-01-10T11:50:17-05:00
New Revision: a02af37560ff5aa22dcef5735ef25eaf58eaaf64
URL: https://github.com/llvm/llvm-project/commit/a02af37560ff5aa22dcef5735ef25eaf58eaaf64
DIFF: https://github.com/llvm/llvm-project/commit/a02af37560ff5aa22dcef5735ef25eaf58eaaf64.diff
LOG: [MLIR] Generalize select to arithmetic canonicalization
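Given a select whose result type is i1, we can eliminate the select entirely by replacing it with a few arithmetic operations: `select %c, %x, %y : i1` becomes `(%c and %x) or ((not %c) and %y)`, where `not %c` is computed as `%c xor true`.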
Reviewed By: ftynse
Differential Revision: https://reviews.llvm.org/D116839
Added:
Modified:
mlir/lib/Dialect/StandardOps/IR/Ops.cpp
mlir/test/Dialect/Standard/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index a1047a58ce2c5..ed6698ba281c7 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -813,29 +813,31 @@ static LogicalResult verify(ReturnOp op) {
// SelectOp
//===----------------------------------------------------------------------===//
-// Transforms a select to a not, where relevant.
+// Transforms a select of a boolean to arithmetic operations
//
-// select %arg, %false, %true
+// select %arg, %x, %y : i1
//
// becomes
//
-// xor %arg, %true
-struct SelectToNot : public OpRewritePattern<SelectOp> {
+// and(%arg, %x) or and(!%arg, %y)
+struct SelectI1Simplify : public OpRewritePattern<SelectOp> {
using OpRewritePattern<SelectOp>::OpRewritePattern;
LogicalResult matchAndRewrite(SelectOp op,
PatternRewriter &rewriter) const override {
- if (!matchPattern(op.getTrueValue(), m_Zero()))
- return failure();
-
- if (!matchPattern(op.getFalseValue(), m_One()))
- return failure();
-
if (!op.getType().isInteger(1))
return failure();
- rewriter.replaceOpWithNewOp<arith::XOrIOp>(op, op.getCondition(),
- op.getFalseValue());
+ Value falseConstant =
+ rewriter.create<arith::ConstantIntOp>(op.getLoc(), true, 1);
+ Value notCondition = rewriter.create<arith::XOrIOp>(
+ op.getLoc(), op.getCondition(), falseConstant);
+
+ Value trueVal = rewriter.create<arith::AndIOp>(
+ op.getLoc(), op.getCondition(), op.getTrueValue());
+ Value falseVal = rewriter.create<arith::AndIOp>(op.getLoc(), notCondition,
+ op.getFalseValue());
+ rewriter.replaceOpWithNewOp<arith::OrIOp>(op, trueVal, falseVal);
return success();
}
};
@@ -876,7 +878,7 @@ struct SelectToExtUI : public OpRewritePattern<SelectOp> {
void SelectOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
- results.insert<SelectToNot, SelectToExtUI>(context);
+ results.insert<SelectI1Simplify, SelectToExtUI>(context);
}
OpFoldResult SelectOp::fold(ArrayRef<Attribute> operands) {
diff --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir
index b44f78f96d070..67d95ce0d194c 100644
--- a/mlir/test/Dialect/Standard/canonicalize.mlir
+++ b/mlir/test/Dialect/Standard/canonicalize.mlir
@@ -88,10 +88,23 @@ func @branchCondProp(%arg0: i1) {
// CHECK-LABEL: @selToNot
// CHECK: %[[trueval:.+]] = arith.constant true
-// CHECK: %{{.+}} = arith.xori %arg0, %[[trueval]] : i1
+// CHECK: %[[res:.+]] = arith.xori %arg0, %[[trueval]] : i1
+// CHECK: return %[[res]]
func @selToNot(%arg0: i1) -> i1 {
%true = arith.constant true
%false = arith.constant false
%res = select %arg0, %false, %true : i1
return %res : i1
}
+
+// CHECK-LABEL: @selToArith
+// CHECK-NEXT: %[[trueval:.+]] = arith.constant true
+// CHECK-NEXT: %[[notcmp:.+]] = arith.xori %arg0, %[[trueval]] : i1
+// CHECK-NEXT: %[[condtrue:.+]] = arith.andi %arg0, %arg1 : i1
+// CHECK-NEXT: %[[condfalse:.+]] = arith.andi %[[notcmp]], %arg2 : i1
+// CHECK-NEXT: %[[res:.+]] = arith.ori %[[condtrue]], %[[condfalse]] : i1
+// CHECK: return %[[res]]
+func @selToArith(%arg0: i1, %arg1 : i1, %arg2 : i1) -> i1 {
+ %res = select %arg0, %arg1, %arg2 : i1
+ return %res : i1
+}
More information about the Mlir-commits mailing list