[Mlir-commits] [mlir] [Vector] Add canonicalization for select(pred, true, false) -> broadcast(pred) (PR #147934)

Jakub Kuderski llvmlistbot at llvm.org
Thu Jul 10 07:54:40 PDT 2025


================
@@ -2913,13 +2913,74 @@ struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
     return success();
   }
 };
+
+/// true:  constant vector of i1
+/// false: constant vector of i1
+/// pred:  scalar i1
+///
+/// select(pred, true, false) -> broadcast(pred)
+/// select(pred, false, true) -> broadcast(not(pred))
+///
+/// Ideally, this would be a canonicalization pattern on arith::SelectOp, but
+/// we cannot have arith depend on vector. Also, it would implicitly force
+/// users who only use the arith dialect to also depend on the vector dialect.
+/// Instead, this canonicalization only runs if vector::BroadcastOp is a
+/// registered operation.
+struct FoldI1SelectToBroadcast : public OpRewritePattern<arith::SelectOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(arith::SelectOp selectOp,
+                                PatternRewriter &rewriter) const override {
+    auto vecType = dyn_cast<VectorType>(selectOp.getType());
+    if (!vecType || !vecType.getElementType().isInteger(1))
+      return failure();
+
+    // Vector conditionals do not need broadcast and are already handled by
+    // the arith.select folder.
+    Value pred = selectOp.getCondition();
+    if (isa<VectorType>(pred.getType()))
+      return failure();
+
+    std::optional<int64_t> trueInt =
+        getConstantIntValue(selectOp.getTrueValue());
+    std::optional<int64_t> falseInt =
+        getConstantIntValue(selectOp.getFalseValue());
+    if (!trueInt || !falseInt)
+      return failure();
+
+    // Redundant selects are already handled by arith.select canonicalizations.
+    if (trueInt.value() == falseInt.value()) {
+      return failure();
+    }
+
+    // The only remaining possibilities are:
+    //
+    // select(pred, true, false)
+    // select(pred, false, true)
+
+    // select(pred, false, true) -> select(not(pred), true, false)
+    if (trueInt.value() == 0) {
+      Value one = rewriter.create<arith::ConstantIntOp>(
+          selectOp.getLoc(), /*value=*/1, /*width=*/1);
+      pred = rewriter.create<arith::XOrIOp>(selectOp.getLoc(), pred, one);
+    }
+
+    /// select(pred, true, false) -> broadcast(pred)
+    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
+        selectOp, vecType.clone(rewriter.getI1Type()), pred);
+    return success();
+
+    return failure();
----------------
kuhar wrote:

dead return
```suggestion
```

https://github.com/llvm/llvm-project/pull/147934


More information about the Mlir-commits mailing list