[Mlir-commits] [mlir] 15a08cf - [mlir][Vector] Fold selects of single-element i1 vectors

Diego Caballero llvmlistbot at llvm.org
Wed Aug 9 12:20:15 PDT 2023


Author: Diego Caballero
Date: 2023-08-09T18:57:36Z
New Revision: 15a08cf27cd48071ab32fcac584419e1dc2174c1

URL: https://github.com/llvm/llvm-project/commit/15a08cf27cd48071ab32fcac584419e1dc2174c1
DIFF: https://github.com/llvm/llvm-project/commit/15a08cf27cd48071ab32fcac584419e1dc2174c1.diff

LOG: [mlir][Vector] Fold selects of single-element i1 vectors

This patch adds a fold for `arith.select` operations that choose between an all-true and
an all-false vector. For now, only single-element vectors (i.e., vector<1xi1>) are
supported; multi-element cases are caught by InstCombine.

Reviewed By: awarzynski

Differential Revision: https://reviews.llvm.org/D154682

Added: 
    mlir/test/Dialect/Vector/vector-materialize-mask.mlir

Modified: 
    mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index a6177641dc6b43..207df69929c1c9 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1067,6 +1067,67 @@ class VectorCreateMaskOpConversion
   const bool force32BitVectorIndices;
 };
 
+/// Returns true if `constantOp` holds a dense `i1` splat equal to `value`.
+static bool allI1ConstantValuesSetTo(arith::ConstantOp constantOp, bool value) {
+  // TODO: Support non-dense constant.
+  if (auto dense = dyn_cast<DenseIntElementsAttr>(constantOp.getValue())) {
+    assert(dense.getElementType().isInteger(1) && "Unexpected type");
+    return dense.isSplat() && dense.getSplatValue<bool>() == value;
+  }
+  // Non-dense constants are conservatively treated as a mismatch.
+  return false;
+}
+
+/// Folds a select between an all-true and an all-false `i1` vector whose
+/// condition is a scalar `i1`. For now, only single-element, fixed-size 1-D
+/// vectors (i.e., vector<1xi1>) are supported. That is:
+///
+///   %true = arith.constant dense<true> : vector<1xi1>
+///   %false = arith.constant dense<false> : vector<1xi1>
+///   %result = arith.select %cond, %true, %false : i1, vector<1xi1>
+///   =>
+///   %result = vector.broadcast %cond : i1 to vector<1xi1>
+///
+/// InstCombine seems to handle the multi-element case but not this one.
+struct FoldI1Select : public OpRewritePattern<arith::SelectOp> {
+  using OpRewritePattern<arith::SelectOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(arith::SelectOp selectOp,
+                                PatternRewriter &rewriter) const override {
+    auto vecType = dyn_cast<VectorType>(selectOp.getType());
+    if (!vecType || !vecType.getElementType().isInteger(1))
+      return failure();
+
+    // Only scalar conditions can be folded.
+    Value cond = selectOp.getCondition();
+    if (isa<VectorType>(cond.getType()))
+      return failure();
+
+    // TODO: Support n-D and scalable vectors.
+    if (vecType.getRank() != 1 || vecType.isScalable())
+      return failure();
+
+    // TODO: Support vectors with multiple elements.
+    if (vecType.getShape()[0] != 1)
+      return failure();
+
+    auto trueConst = selectOp.getTrueValue().getDefiningOp<arith::ConstantOp>();
+    if (!trueConst || !allI1ConstantValuesSetTo(trueConst, true))
+      return failure();
+
+    auto falseConst =
+        selectOp.getFalseValue().getDefiningOp<arith::ConstantOp>();
+    if (!falseConst || !allI1ConstantValuesSetTo(falseConst, false))
+      return failure();
+
+    // Replace the select with its scalar condition broadcast to the result
+    // type. Reuse `vecType` directly: the broadcast must produce `i1` elements,
+    // not an integer type whose bit width happens to equal the element count.
+    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(selectOp, vecType, cond);
+    return success();
+  }
+};
+
 // Drop inner most contiguous unit dimensions from transfer_read operand.
 class DropInnerMostUnitDims : public OpRewritePattern<vector::TransferReadOp> {
   using OpRewritePattern::OpRewritePattern;
@@ -1322,6 +1383,7 @@ void mlir::vector::populateVectorMaskMaterializationPatterns(
                MaterializeTransferMask<vector::TransferReadOp>,
                MaterializeTransferMask<vector::TransferWriteOp>>(
       patterns.getContext(), force32BitVectorIndices, benefit);
+  patterns.add<FoldI1Select>(patterns.getContext(), benefit);
 }
 
 void mlir::vector::populateShapeCastFoldingPatterns(RewritePatternSet &patterns,

diff  --git a/mlir/test/Dialect/Vector/vector-materialize-mask.mlir b/mlir/test/Dialect/Vector/vector-materialize-mask.mlir
new file mode 100644
index 00000000000000..3d3d643168cdb4
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-materialize-mask.mlir
@@ -0,0 +1,42 @@
+// RUN: mlir-opt %s --test-transform-dialect-interpreter --split-input-file | FileCheck %s
+
+func.func @select_single_i1_vector(%cond : i1) -> vector<1xi1> {
+  %true = arith.constant dense<true> : vector<1xi1>
+  %false = arith.constant dense<false> : vector<1xi1>
+  %select = arith.select %cond, %true, %false : i1, vector<1xi1>
+  return %select : vector<1xi1>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%func_op: !transform.op<"func.func">):
+  transform.apply_patterns to %func_op {
+    transform.apply_patterns.vector.materialize_masks
+  } : !transform.op<"func.func">
+}
+
+// CHECK-LABEL: func @select_single_i1_vector
+// CHECK-SAME: %[[COND:.*]]: i1
+// CHECK:      %[[BCAST:.*]] = vector.broadcast %[[COND]] : i1 to vector<1xi1>
+// CHECK:      return %[[BCAST]] : vector<1xi1>
+
+// -----
+
+// Multi-element vectors are not folded yet; the select must be preserved.
+func.func @select_multi_i1_vector(%cond : i1) -> vector<2xi1> {
+  %true = arith.constant dense<true> : vector<2xi1>
+  %false = arith.constant dense<false> : vector<2xi1>
+  %select = arith.select %cond, %true, %false : i1, vector<2xi1>
+  return %select : vector<2xi1>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%func_op: !transform.op<"func.func">):
+  transform.apply_patterns to %func_op {
+    transform.apply_patterns.vector.materialize_masks
+  } : !transform.op<"func.func">
+}
+
+// CHECK-LABEL: func @select_multi_i1_vector
+// CHECK-NOT:   vector.broadcast
+// CHECK:       %[[SEL:.*]] = arith.select
+// CHECK:       return %[[SEL]] : vector<2xi1>


        


More information about the Mlir-commits mailing list