[Mlir-commits] [mlir] [MLIR][Vector] Move scalable dims check before IR creation in BroadcastOpLowering (PR #188953)
llvmlistbot at llvm.org
Fri Mar 27 03:27:53 PDT 2026
llvmbot wrote:
@llvm/pr-subscribers-mlir-vector
Author: Mehdi Amini (joker-eph)
<details>
<summary>Changes</summary>
BroadcastOpLowering::matchAndRewrite created a ub::PoisonOp before checking whether a scalable outer dimension rules out the "stretch not at start" case. When that check triggered, the pattern returned failure() after the IR had already been modified. This violates the pattern API contract, and the violation is caught when MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS is enabled.
Fix: move the scalable dimension check to before the PoisonOp creation.
Assisted-by: Claude Code
This fixes a failure that occurs when building with MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS=ON.
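The ordering rule behind this fix can be illustrated with a toy model. The sketch below is not MLIR code: ToyRewriter, ToyVectorType, emitUnrolled, lowerBuggy, lowerFixed and respectsContract are hypothetical stand-ins that only record which ops would be created, and the op sequence is a simplified reading of the diff. Under those assumptions it shows that the old ordering leaves a ub.poison behind on a failed match, while the new ordering fails with nothing created.

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical stand-in for a rewriter: it only records the ops it "creates".
struct ToyRewriter {
  std::vector<std::string> createdOps;
  void create(const std::string &name) { createdOps.push_back(name); }
};

// Hypothetical stand-in for VectorType: extents plus per-dim scalable flags.
struct ToyVectorType {
  std::vector<int64_t> shape;
  std::vector<bool> scalableDims;
};

// Emits the unrolled body; `m` is the first dimension where stretching occurs.
void emitUnrolled(ToyRewriter &rewriter, const ToyVectorType &dst, int64_t m) {
  if (m == 0) {
    // Stretch at start: one extract + broadcast, inserted at every position.
    rewriter.create("vector.extract");
    rewriter.create("vector.broadcast");
    for (int64_t d = 0; d < dst.shape[0]; ++d)
      rewriter.create("vector.insert");
  } else {
    // Stretch not at start: extract, broadcast and insert per position.
    for (int64_t d = 0; d < dst.shape[0]; ++d) {
      rewriter.create("vector.extract");
      rewriter.create("vector.broadcast");
      rewriter.create("vector.insert");
    }
  }
}

// Old ordering: the poison value is created before the bail-out check.
bool lowerBuggy(ToyRewriter &rewriter, const ToyVectorType &dst, int64_t m) {
  rewriter.create("ub.poison");
  if (m != 0 && dst.scalableDims[0])
    return false; // failure, but an op was already created
  emitUnrolled(rewriter, dst, m);
  return true;
}

// New ordering: the check only reads types and runs before any op is created.
bool lowerFixed(ToyRewriter &rewriter, const ToyVectorType &dst, int64_t m) {
  assert(dst.shape.size() == dst.scalableDims.size());
  if (m != 0 && dst.scalableDims[0])
    return false; // failure with nothing created
  rewriter.create("ub.poison");
  emitUnrolled(rewriter, dst, m);
  return true;
}

// Toy version of the contract: a failed match must leave the IR untouched.
bool respectsContract(bool succeeded, const ToyRewriter &rewriter) {
  return succeeded || rewriter.createdOps.empty();
}
```

Because the check only inspects types, hoisting it above the first create call costs nothing, and it is what lets the pattern report failure without having touched the IR.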
---
Full diff: https://github.com/llvm/llvm-project/pull/188953.diff
1 File Affected:
- (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp (+10-4)
``````````diff
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
index 61d9357e19bb4..66dd7c8f36e6b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
@@ -117,6 +117,16 @@ class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
VectorType resType =
VectorType::get(dstType.getShape().drop_front(), eltType,
dstType.getScalableDims().drop_front());
+
+ // For "stretch not at start" with a scalable outer dimension we would need
+ // to emit an scf.for loop, which is not yet supported. Check before
+ // creating any IR so that returning failure() does not violate the pattern
+ // API contract.
+ if (m != 0 && dstType.getScalableDims()[0]) {
+ // TODO: For scalable vectors we should emit an scf.for loop.
+ return failure();
+ }
+
Value result = ub::PoisonOp::create(rewriter, loc, dstType);
if (m == 0) {
// Stetch at start.
@@ -126,10 +136,6 @@ class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
result = vector::InsertOp::create(rewriter, loc, bcst, result, d);
} else {
// Stetch not at start.
- if (dstType.getScalableDims()[0]) {
- // TODO: For scalable vectors we should emit an scf.for loop.
- return failure();
- }
for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d) {
Value ext = vector::ExtractOp::create(rewriter, loc, op.getSource(), d);
Value bcst = vector::BroadcastOp::create(rewriter, loc, resType, ext);
``````````
</details>
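To make the unrolling in the diff concrete, here is a toy C++ model of its two branches on plain nested std::vector values. It assumes a 2-D source and result of the same rank; the real pattern also handles rank extension and other cases not modeled here, and Vec1D, Vec2D, broadcast1D and lowerBroadcast2D are hypothetical names. The loop bound is the result's leading extent, which is why the not-at-start branch needs that extent as a compile-time constant and, per the TODO, would need an scf.for loop when the outer dimension is scalable.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Vec1D = std::vector<float>;
using Vec2D = std::vector<Vec1D>;

// 1-D broadcast: a size-1 source is replicated, a same-size source is copied.
Vec1D broadcast1D(const Vec1D &src, std::size_t n) {
  assert(src.size() == 1 || src.size() == n);
  return src.size() == n ? src : Vec1D(n, src[0]);
}

// Toy 2-D -> 2-D broadcast, unrolled over the leading dimension like the diff.
Vec2D lowerBroadcast2D(const Vec2D &src, std::size_t rows, std::size_t cols) {
  assert(!src.empty() && (src.size() == 1 || src.size() == rows));
  // m == 0 when the leading dimension itself is stretched (1 -> rows).
  bool stretchAtStart = src.size() != rows;
  Vec2D result(rows, Vec1D(cols, 0.0f)); // placeholder for the ub.poison value
  if (stretchAtStart) {
    // One extract at position 0 and one broadcast, inserted at every row.
    Vec1D bcst = broadcast1D(src[0], cols);
    for (std::size_t d = 0; d < rows; ++d)
      result[d] = bcst;
  } else {
    // Per row: extract src[d], broadcast it to `cols`, insert at row d.
    for (std::size_t d = 0; d < rows; ++d)
      result[d] = broadcast1D(src[d], cols);
  }
  return result;
}
```

For example, a 2x1 source broadcast to 2x3 takes the not-at-start branch and replicates each row's single element three times, while a 1x3 source broadcast to 2x3 takes the stretch-at-start branch and copies the same row twice.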
https://github.com/llvm/llvm-project/pull/188953