[Mlir-commits] [mlir] [MLIR][Vector] Move scalable dims check before IR creation in ScanToArithOps (PR #188954)

Mehdi Amini llvmlistbot at llvm.org
Fri Mar 27 03:27:50 PDT 2026


https://github.com/joker-eph created https://github.com/llvm/llvm-project/pull/188954

ScanToArithOps::matchAndRewrite created the result arith.constant before checking whether the reduction dimension is scalable. When it was scalable, the pattern returned notifyMatchFailure() after the IR had already been modified. A pattern that reports failure must leave the IR unchanged, so this tripped the checks enabled by MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS.

Fix: move the reductionScalableDims check ahead of the arith::ConstantOp creation.

Assisted-by: Claude Code
Fix a failure present with MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS=ON.
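The ordering problem can be illustrated with a small, self-contained toy model. Everything below (`ToyIR`, `ToyScan`, `buggyRewrite`, `fixedRewrite`, `respectsContract`) is hypothetical and does not use MLIR's API. It only mimics the shape of the bug: the old ordering mutates the IR and then reports failure, while the new ordering runs the scalability check first. `respectsContract` is a crude stand-in for the expensive pattern API checks, reduced to the one invariant that matters here: a pattern that fails must leave the IR as it found it.

```cpp
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical stand-in for the IR: just the names of the ops it contains.
struct ToyIR {
  std::vector<std::string> ops;
};

// Hypothetical stand-in for vector.scan: only the fields the check reads.
struct ToyScan {
  std::vector<bool> scalableDims;  // mirrors destType.getScalableDims()
  std::int64_t reductionDim;       // mirrors scanOp.getReductionDim()
};

enum class Result { Success, Failure };

// Old ordering: the zero constant is created before the scalability check.
Result buggyRewrite(ToyIR &ir, const ToyScan &scan) {
  ir.ops.push_back("arith.constant");  // IR is mutated here...
  if (scan.scalableDims[scan.reductionDim])
    return Result::Failure;            // ...and then failure is reported.
  ir.ops.push_back("<rest of lowering>");  // placeholder, not a real op
  return Result::Success;
}

// New ordering: bail out before anything is created.
Result fixedRewrite(ToyIR &ir, const ToyScan &scan) {
  if (scan.scalableDims[scan.reductionDim])
    return Result::Failure;            // IR is still untouched.
  ir.ops.push_back("arith.constant");
  ir.ops.push_back("<rest of lowering>");  // placeholder, not a real op
  return Result::Success;
}

// Toy version of the invariant: a failed pattern must not have changed the
// IR. A successful pattern is allowed to change anything.
template <typename Pattern>
bool respectsContract(Pattern pattern, const ToyScan &scan) {
  ToyIR ir{{"vector.scan"}};
  const std::vector<std::string> before = ir.ops;
  Result r = pattern(ir, scan);
  return r == Result::Success || ir.ops == before;
}
```

With a scalable reduction dimension (for example `ToyScan{{false, true}, 1}`), `buggyRewrite` fails after adding the constant and so violates the invariant, while `fixedRewrite` fails cleanly. The same rule applies to every early exit in a real pattern: anything that can return failure() or notifyMatchFailure() belongs before the first op is created through the rewriter.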

From fc55110785a0af5978c45c4fa8d3da5ef4ea4163 Mon Sep 17 00:00:00 2001
From: Mehdi Amini <joker.eph at gmail.com>
Date: Thu, 26 Mar 2026 15:57:43 -0700
Subject: [PATCH] [MLIR][Vector] Move scalable dims check before IR creation in
 ScanToArithOps

ScanToArithOps::matchAndRewrite created the result arith.constant before
checking whether the reduction dimension is scalable. When the dimension was
scalable, the pattern returned notifyMatchFailure() after IR was already
modified, violating MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS.

Fix: move the reductionScalableDims check to before the arith::ConstantOp
creation.

Assisted-by: Claude Code
Fix a failure present with MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS=ON.
---
 mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp
index 1af552362a26a..f3e239bedb7f7 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp
@@ -111,9 +111,6 @@ struct ScanToArithOps : public OpRewritePattern<vector::ScanOp> {
     if (!isValidKind(isInt, scanOp.getKind()))
       return failure();
 
-    VectorType resType = destType;
-    Value result = arith::ConstantOp::create(rewriter, loc, resType,
-                                             rewriter.getZeroAttr(resType));
     int64_t reductionDim = scanOp.getReductionDim();
     bool inclusive = scanOp.getInclusive();
     int64_t destRank = destType.getRank();
@@ -123,10 +120,16 @@ struct ScanToArithOps : public OpRewritePattern<vector::ScanOp> {
     SmallVector<int64_t> reductionShape(destShape);
     SmallVector<bool> reductionScalableDims(destType.getScalableDims());
 
+    // Check before creating any IR so that returning failure() does not
+    // violate the pattern API contract.
     if (reductionScalableDims[reductionDim])
       return rewriter.notifyMatchFailure(
           scanOp, "Trying to reduce scalable dimension - not yet supported!");
 
+    VectorType resType = destType;
+    Value result = arith::ConstantOp::create(rewriter, loc, resType,
+                                             rewriter.getZeroAttr(resType));
+
     // The reduction dimension, after reducing, becomes 1. It's a fixed-width
     // dimension - no need to touch the scalability flag.
     reductionShape[reductionDim] = 1;