[Mlir-commits] [mlir] [Vector] Add canonicalization for select(pred, true, false) -> broadcast(pred) (PR #147934)
Jakub Kuderski
llvmlistbot at llvm.org
Thu Jul 10 07:54:40 PDT 2025
================
@@ -2913,13 +2913,74 @@ struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
return success();
}
};
+
+/// true: vector
+/// false: vector
+/// pred: i1
+///
+/// select(pred, true, false) -> broadcast(pred)
+/// select(pred, false, true) -> broadcast(not(pred))
+///
+/// Ideally, this would be a canonicalization pattern on arith::SelectOp, but
+/// the arith dialect cannot depend on the vector dialect. It would also
+/// implicitly force users who only use the arith dialect (with vector types)
+/// to load the vector dialect as well. Instead, this canonicalization only
+/// runs if vector::BroadcastOp is a registered operation.
+struct FoldI1SelectToBroadcast : public OpRewritePattern<arith::SelectOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(arith::SelectOp selectOp,
+ PatternRewriter &rewriter) const override {
+ auto vecType = dyn_cast<VectorType>(selectOp.getType());
+ if (!vecType || !vecType.getElementType().isInteger(1))
+ return failure();
+
+ // Vector conditionals do not need broadcast and are already handled by
+ // the arith.select folder.
+ Value pred = selectOp.getCondition();
+ if (isa<VectorType>(pred.getType()))
+ return failure();
+
+ std::optional<int64_t> trueInt =
+ getConstantIntValue(selectOp.getTrueValue());
+ std::optional<int64_t> falseInt =
+ getConstantIntValue(selectOp.getFalseValue());
+ if (!trueInt || !falseInt)
+ return failure();
+
+ // Redundant selects are already handled by arith.select canonicalizations.
+ if (trueInt.value() == falseInt.value()) {
+ return failure();
+ }
+
+ // The only remaining possibilities are:
+ //
+ // select(pred, true, false)
+ // select(pred, false, true)
+
+ // select(pred, false, true) -> select(not(pred), true, false)
+ if (trueInt.value() == 0) {
+ Value one = rewriter.create<arith::ConstantIntOp>(
+ selectOp.getLoc(), /*value=*/1, /*width=*/1);
+ pred = rewriter.create<arith::XOrIOp>(selectOp.getLoc(), pred, one);
+ }
+
+ /// select(pred, true, false) -> broadcast(pred)
+ rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
+ selectOp, vecType.clone(rewriter.getI1Type()), pred);
+ return success();
+
+ return failure();
----------------
kuhar wrote:
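dead return: the final `return failure();` follows `return success();` and is unreachable, so it should be removed.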
```suggestion
```
https://github.com/llvm/llvm-project/pull/147934
More information about the Mlir-commits
mailing list