[Mlir-commits] [mlir] 15a08cf - [mlir][Vector] Fold selects of single-element i1 vectors
Diego Caballero
llvmlistbot at llvm.org
Wed Aug 9 12:20:15 PDT 2023
Author: Diego Caballero
Date: 2023-08-09T18:57:36Z
New Revision: 15a08cf27cd48071ab32fcac584419e1dc2174c1
URL: https://github.com/llvm/llvm-project/commit/15a08cf27cd48071ab32fcac584419e1dc2174c1
DIFF: https://github.com/llvm/llvm-project/commit/15a08cf27cd48071ab32fcac584419e1dc2174c1.diff
LOG: [mlir][Vector] Fold selects of single-element i1 vectors
This patch adds a fold for select operations that choose between an all-true and an
all-false vector. For now, only single-element vectors (i.e., vector<1xi1>) are
supported. Multi-element cases are caught by InstCombine.
Reviewed By: awarzynski
Differential Revision: https://reviews.llvm.org/D154682
Added:
mlir/test/Dialect/Vector/vector-materialize-mask.mlir
Modified:
mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index a6177641dc6b43..207df69929c1c9 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1067,6 +1067,67 @@ class VectorCreateMaskOpConversion
const bool force32BitVectorIndices;
};
+/// Checks whether `constantOp` holds a dense splat of `i1` elements that are
+/// all equal to `value`. Constants that are not dense yield false.
+static bool allI1ConstantValuesSetTo(arith::ConstantOp constantOp, bool value) {
+  // TODO: Support non-dense constant.
+  if (auto elements = dyn_cast<DenseIntElementsAttr>(constantOp.getValue())) {
+    assert(elements.getElementType().isInteger(1) && "Unexpected type");
+    return elements.isSplat() && elements.getSplatValue<bool>() == value;
+  }
+  return false;
+}
+
+/// Folds a select operation between an all-true and all-false vector. For now,
+/// only single element vectors (i.e., vector<1xi1>) are supported. That is:
+///
+///   %true = arith.constant dense<true> : vector<1xi1>
+///   %false = arith.constant dense<false> : vector<1xi1>
+///   %result = arith.select %cond, %true, %false : i1, vector<1xi1>
+///  =>
+///   %result = vector.broadcast %cond : i1 to vector<1xi1>
+///
+/// InstCombine seems to handle vectors with multiple elements but not the
+/// single element ones.
+struct FoldI1Select : public OpRewritePattern<arith::SelectOp> {
+  using OpRewritePattern<arith::SelectOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(arith::SelectOp selectOp,
+                                PatternRewriter &rewriter) const override {
+    auto vecType = dyn_cast<VectorType>(selectOp.getType());
+    if (!vecType || !vecType.getElementType().isInteger(1))
+      return failure();
+
+    // Only scalar conditions can be folded.
+    Value cond = selectOp.getCondition();
+    if (isa<VectorType>(cond.getType()))
+      return failure();
+
+    // TODO: Support n-D and scalable vectors.
+    if (vecType.getRank() != 1 || vecType.isScalable())
+      return failure();
+
+    // TODO: Support vectors with multiple elements.
+    if (vecType.getShape()[0] != 1)
+      return failure();
+
+    auto trueConst = selectOp.getTrueValue().getDefiningOp<arith::ConstantOp>();
+    if (!trueConst || !allI1ConstantValuesSetTo(trueConst, true))
+      return failure();
+
+    auto falseConst =
+        selectOp.getFalseValue().getDefiningOp<arith::ConstantOp>();
+    if (!falseConst || !allI1ConstantValuesSetTo(falseConst, false))
+      return failure();
+
+    // Replace the select with its condition broadcast to the result type.
+    // Reusing `vecType` keeps the element type `i1` by construction rather
+    // than deriving the integer bit width from the number of elements.
+    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(selectOp, vecType, cond);
+    return success();
+  }
+};
+
// Drop inner most contiguous unit dimensions from transfer_read operand.
class DropInnerMostUnitDims : public OpRewritePattern<vector::TransferReadOp> {
using OpRewritePattern::OpRewritePattern;
@@ -1322,6 +1383,7 @@ void mlir::vector::populateVectorMaskMaterializationPatterns(
MaterializeTransferMask<vector::TransferReadOp>,
MaterializeTransferMask<vector::TransferWriteOp>>(
patterns.getContext(), force32BitVectorIndices, benefit);
+ patterns.add<FoldI1Select>(patterns.getContext(), benefit);
}
void mlir::vector::populateShapeCastFoldingPatterns(RewritePatternSet &patterns,
diff --git a/mlir/test/Dialect/Vector/vector-materialize-mask.mlir b/mlir/test/Dialect/Vector/vector-materialize-mask.mlir
new file mode 100644
index 00000000000000..3d3d643168cdb4
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-materialize-mask.mlir
@@ -0,0 +1,20 @@
+// RUN: mlir-opt %s --test-transform-dialect-interpreter --split-input-file | FileCheck %s
+// Select between all-true/all-false vector<1xi1> constants on a scalar i1.
+func.func @select_single_i1_vector(%cond : i1) -> vector<1xi1> {
+ %true = arith.constant dense<true> : vector<1xi1>
+ %false = arith.constant dense<false> : vector<1xi1>
+ %select = arith.select %cond, %true, %false : i1, vector<1xi1>
+ return %select : vector<1xi1>
+}
+// Applies the vector `materialize_masks` patterns to the `func.func` handle.
+transform.sequence failures(propagate) {
+^bb1(%func_op: !transform.op<"func.func">):
+ transform.apply_patterns to %func_op {
+ transform.apply_patterns.vector.materialize_masks
+ } : !transform.op<"func.func">
+}
+
+// CHECK-LABEL: func @select_single_i1_vector
+// CHECK-SAME: %[[COND:.*]]: i1
+// CHECK: %[[BCAST:.*]] = vector.broadcast %[[COND]] : i1 to vector<1xi1>
+// CHECK: return %[[BCAST]] : vector<1xi1>
More information about the Mlir-commits
mailing list