[Mlir-commits] [mlir] [Vector] Add folder for select(pred, true, false) -> broadcast(pred) (PR #147934)

Kunwar Grover llvmlistbot at llvm.org
Thu Jul 10 03:26:05 PDT 2025


https://github.com/Groverkss created https://github.com/llvm/llvm-project/pull/147934

None

>From bd6d20d16be7d925cf08c3f292a383f0389d4d8f Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Thu, 10 Jul 2025 11:24:57 +0100
Subject: [PATCH] [Vector] Add folder for select(pred, true, false) ->
 broadcast(pred)

---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp   | 63 +++++++++++++++++++++-
 mlir/test/Dialect/Vector/canonicalize.mlir | 32 +++++++++++
 2 files changed, 94 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 1fb8c7a928e06..39c8191e8451a 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2913,13 +2913,74 @@ struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
     return success();
   }
 };
+
+/// Folds an i1-vector arith.select on a scalar i1 condition into a broadcast.
+///
+/// true: vector
+/// false: vector
+/// pred: i1
+///
+/// select(pred, true, false) -> broadcast(pred)
+/// select(pred, false, true) -> broadcast(not(pred))
+///
+/// Ideally, this would be a canonicalization pattern on arith::SelectOp, but
+/// arith cannot depend on vector. Instead, the pattern is registered with
+/// vector::BroadcastOp's canonicalization patterns, so it only runs when
+/// vector::BroadcastOp is a registered operation.
+struct FoldI1SelectToBroadcast : public OpRewritePattern<arith::SelectOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(arith::SelectOp selectOp,
+                                PatternRewriter &rewriter) const override {
+    // Only signless i1 vectors qualify: the broadcast source is the scalar
+    // condition, so the result element type has to be the same signless i1.
+    auto vecType = dyn_cast<VectorType>(selectOp.getType());
+    if (!vecType || !vecType.getElementType().isSignlessInteger(1))
+      return failure();
+
+    // Vector conditionals do not need broadcast and are already handled by
+    // the arith.select folder.
+    Value pred = selectOp.getCondition();
+    if (isa<VectorType>(pred.getType()))
+      return failure();
+
+    std::optional<int64_t> trueInt =
+        getConstantIntValue(selectOp.getTrueValue());
+    std::optional<int64_t> falseInt =
+        getConstantIntValue(selectOp.getFalseValue());
+    if (!trueInt || !falseInt)
+      return failure();
+
+    // Redundant selects are already handled by arith.select canonicalizations.
+    if (*trueInt == *falseInt)
+      return failure();
+
+    // The only remaining possibilities are:
+    //
+    // select(pred, true, false)
+    // select(pred, false, true)
+
+    // select(pred, false, true) -> select(not(pred), true, false)
+    if (*trueInt == 0) {
+      Value one = rewriter.create<arith::ConstantIntOp>(
+          selectOp.getLoc(), /*value=*/1, /*width=*/1);
+      pred = rewriter.create<arith::XOrIOp>(selectOp.getLoc(), pred, one);
+    }
+
+    // select(pred, true, false) -> broadcast(pred). The element type is
+    // already i1, so the select's own vector type is the broadcast type.
+    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(selectOp, vecType, pred);
+    return success();
+  }
+};
+
 } // namespace
 
 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
   // BroadcastToShapeCast is not a default canonicalization, it is opt-in by
   // calling `populateCastAwayVectorLeadingOneDimPatterns`
-  results.add<BroadcastFolder>(context);
+  results.add<BroadcastFolder, FoldI1SelectToBroadcast>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 0282e9cac5e02..5924e7ea856c4 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1057,6 +1057,38 @@ func.func @canonicalize_broadcast_shapecast_both_possible(%arg0: vector<1xf32>)
 
 // -----
 
+// CHECK-LABEL: func.func @canonicalize_i1_select_to_broadcast
+// CHECK-SAME: (%[[PRED:.+]]: i1)
+// CHECK: vector.broadcast %[[PRED]] : i1 to vector<4xi1>
+func.func @canonicalize_i1_select_to_broadcast(%pred: i1) -> vector<4xi1> {
+  %true = arith.constant dense<true> : vector<4x4xi1>
+  %false = arith.constant dense<false> : vector<4x4xi1>
+  %selected = arith.select %pred, %true, %false : vector<4x4xi1>
+  // The select -> broadcast pattern only loads if vector dialect was loaded.
+  // Force loading vector dialect by adding a vector operation.
+  %vec = vector.extract %selected[0] : vector<4xi1> from vector<4x4xi1>
+  return %vec : vector<4xi1>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @canonicalize_i1_select_to_not_broadcast
+// CHECK-SAME: (%[[PRED:.+]]: i1)
+// CHECK: %[[TRUE:.+]] = arith.constant true
+// CHECK: %[[NOT:.+]] = arith.xori %[[PRED]], %[[TRUE]] : i1
+// CHECK: vector.broadcast %[[NOT]] : i1 to vector<4xi1>
+func.func @canonicalize_i1_select_to_not_broadcast(%pred: i1) -> vector<4xi1> {
+  %true = arith.constant dense<true> : vector<4x4xi1>
+  %false = arith.constant dense<false> : vector<4x4xi1>
+  %selected = arith.select %pred, %false, %true : vector<4x4xi1>
+  // The select -> broadcast pattern only loads if vector dialect was loaded.
+  // Force loading vector dialect by adding a vector operation.
+  %vec = vector.extract %selected[0] : vector<4xi1> from vector<4x4xi1>
+  return %vec : vector<4xi1>
+}
+
+// -----
+
 // CHECK-LABEL: fold_vector_transfer_masks
 func.func @fold_vector_transfer_masks(%A: memref<?x?xf32>) -> (vector<4x8xf32>, vector<4x[4]xf32>) {
   // CHECK: %[[C0:.+]] = arith.constant 0 : index



More information about the Mlir-commits mailing list