[Mlir-commits] [mlir] [Vector] Add folder for select(pred, true, false) -> broadcast(pred) (PR #147934)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Jul 10 03:26:38 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Kunwar Grover (Groverkss)
<details>
<summary>Changes</summary>
---
Full diff: https://github.com/llvm/llvm-project/pull/147934.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+62-1)
- (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+32)
``````````diff
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 1fb8c7a928e06..39c8191e8451a 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2913,13 +2913,74 @@ struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
return success();
}
};
+
+/// Folds an `arith.select` that produces an i1 vector from a scalar i1
+/// condition and two complementary i1 constants into a `vector.broadcast`:
+///
+///   select(pred, true, false) -> broadcast(pred)
+///   select(pred, false, true) -> broadcast(not(pred))
+///
+/// where `pred` is a scalar `i1` and `true`/`false` are constant i1 vectors of
+/// the select's result type. `not(pred)` is materialized as
+/// `arith.xori pred, true`.
+///
+/// Ideally, this would be a canonicalization pattern on arith::SelectOp, but
+/// the arith dialect cannot depend on the vector dialect; doing so would also
+/// force users of the arith dialect alone to load the vector dialect. Instead,
+/// the pattern is registered as a canonicalization of vector::BroadcastOp, so
+/// it only runs when the vector dialect is loaded.
+struct FoldI1SelectToBroadcast : public OpRewritePattern<arith::SelectOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(arith::SelectOp selectOp,
+                                PatternRewriter &rewriter) const override {
+    auto vecType = dyn_cast<VectorType>(selectOp.getType());
+    if (!vecType || !vecType.getElementType().isInteger(1))
+      return failure();
+
+    // Vector conditions need no broadcast and are already handled by the
+    // arith.select folder.
+    Value pred = selectOp.getCondition();
+    if (isa<VectorType>(pred.getType()))
+      return failure();
+
+    // Both operands must be integer constants.
+    std::optional<int64_t> trueInt =
+        getConstantIntValue(selectOp.getTrueValue());
+    std::optional<int64_t> falseInt =
+        getConstantIntValue(selectOp.getFalseValue());
+    if (!trueInt || !falseInt)
+      return failure();
+
+    // Redundant selects are already handled by arith.select canonicalizations.
+    if (*trueInt == *falseInt)
+      return failure();
+
+    // An i1 constant has only two values, so the remaining cases are
+    // select(pred, true, false) and select(pred, false, true). Rewrite the
+    // latter as select(not(pred), true, false). Compare against 0 because
+    // false is 0 whether the i1 value is sign- or zero-extended.
+    if (*trueInt == 0) {
+      Value one = rewriter.create<arith::ConstantIntOp>(
+          selectOp.getLoc(), /*value=*/1, /*width=*/1);
+      pred = rewriter.create<arith::XOrIOp>(selectOp.getLoc(), pred, one);
+    }
+
+    // select(pred, true, false) -> broadcast(pred). The result type already
+    // has an i1 element type (checked above), so it can be reused as is.
+    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(selectOp, vecType, pred);
+    return success();
+  }
+};
+
} // namespace
void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
// BroadcastToShapeCast is not a default canonicalization, it is opt-in by
// calling `populateCastAwayVectorLeadingOneDimPatterns`
- results.add<BroadcastFolder>(context);
+ // FoldI1SelectToBroadcast matches arith.select rather than vector.broadcast;
+ // it is registered here so that it only runs when the vector dialect is
+ // loaded (the arith dialect cannot depend on the vector dialect).
+ results.add<BroadcastFolder, FoldI1SelectToBroadcast>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 0282e9cac5e02..5924e7ea856c4 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1057,6 +1057,38 @@ func.func @canonicalize_broadcast_shapecast_both_possible(%arg0: vector<1xf32>)
// -----
+// A select with a scalar i1 condition between splat true/false i1 vectors
+// becomes a broadcast of the condition; the extract of that broadcast is
+// expected to fold into a single broadcast to vector<4xi1>.
+// CHECK-LABEL: func.func @canonicalize_i1_select_to_broadcast
+// CHECK-SAME: (%[[PRED:.+]]: i1)
+// CHECK: vector.broadcast %[[PRED]] : i1 to vector<4xi1>
+func.func @canonicalize_i1_select_to_broadcast(%pred: i1) -> vector<4xi1> {
+ %true = arith.constant dense<true> : vector<4x4xi1>
+ %false = arith.constant dense<false> : vector<4x4xi1>
+ %selected = arith.select %pred, %true, %false : vector<4x4xi1>
+ // The select -> broadcast pattern is registered with vector.broadcast's
+ // canonicalizations, so it only runs when the vector dialect is loaded.
+ // The vector.extract below forces the vector dialect to be loaded.
+ %vec = vector.extract %selected[0] : vector<4xi1> from vector<4x4xi1>
+ return %vec : vector<4xi1>
+}
+
+// -----
+
+// With the operands swapped (false, true), the condition is first negated
+// via arith.xori with a `true` constant and then broadcast.
+// CHECK-LABEL: func.func @canonicalize_i1_select_to_not_broadcast
+// CHECK-SAME: (%[[PRED:.+]]: i1)
+// CHECK: %[[TRUE:.+]] = arith.constant true
+// CHECK: %[[NOT:.+]] = arith.xori %[[PRED]], %[[TRUE]] : i1
+// CHECK: vector.broadcast %[[NOT]] : i1 to vector<4xi1>
+func.func @canonicalize_i1_select_to_not_broadcast(%pred: i1) -> vector<4xi1> {
+ %true = arith.constant dense<true> : vector<4x4xi1>
+ %false = arith.constant dense<false> : vector<4x4xi1>
+ %selected = arith.select %pred, %false, %true : vector<4x4xi1>
+ // The select -> broadcast pattern is registered with vector.broadcast's
+ // canonicalizations, so it only runs when the vector dialect is loaded.
+ // The vector.extract below forces the vector dialect to be loaded.
+ %vec = vector.extract %selected[0] : vector<4xi1> from vector<4x4xi1>
+ return %vec : vector<4xi1>
+}
+
+// -----
+
// CHECK-LABEL: fold_vector_transfer_masks
func.func @fold_vector_transfer_masks(%A: memref<?x?xf32>) -> (vector<4x8xf32>, vector<4x[4]xf32>) {
// CHECK: %[[C0:.+]] = arith.constant 0 : index
``````````
</details>
https://github.com/llvm/llvm-project/pull/147934
More information about the Mlir-commits mailing list