[Mlir-commits] [mlir] [mlir][vector] Fix incorrect byte-alignment assumption in ConvertVectorStore (PR #189235)

llvmlistbot at llvm.org
Sun Mar 29 06:02:23 PDT 2026


llvmbot wrote:



@llvm/pr-subscribers-mlir-vector

Author: Mehdi Amini (joker-eph)

<details>
<summary>Changes</summary>

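When `ConvertVectorStore` emits the narrow-type emulation for a `vector.store` into a 2-D memref, it previously made an assumption about the last-dimension index. If the vector length is a multiple of the number of narrow elements per byte (`isDivisibleInSize`), and the memref's trailing dimension matches the vector size (`trailingDimsMatch`, which is also true when the trailing dimension is dynamic), it assumed that index must be zero. In that case no sub-byte alignment adjustment was considered necessary. This assumption is wrong: a valid store such as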

  vector.store %v, %src[%c0, %c1] : memref<3x4xi2>, vector<4xi2>

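has a non-zero column index (%c1 == 1) even though trailingDim (4) equals the vector size (4). The incorrect shortcut sent the pattern down the "aligned" path, which emits a plain bitcast + store at byte offset 0. That store overwrites the whole first byte and writes the four vector elements to [0, 0]..[0, 3] instead of [0, 1]..[1, 0]. As a result, it clobbers element [0, 0], which the store must not touch, and it never writes element [1, 0], which should receive the last vector element.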

Fix: prefer `linearizedInfo.intraDataOffset`, which gives the exact offset (in narrow elements) of the first stored element within its byte whenever it folds to a constant, as it does for any constant-index store. Only fall back to the old `0` assumption when `intraDataOffset` cannot be folded to a constant (e.g. because an index is dynamic) **and** `isDivisibleInSize && trailingDimsMatch` still holds. This preserves the existing behaviour for dynamic-index stores while fixing the constant-index case.

Fixes #131528

Assisted-by: Claude Code

---
Full diff: https://github.com/llvm/llvm-project/pull/189235.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp (+17-9) 
- (modified) mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir (+24) 


``````````diff
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 583cda7ac2810..60161cdf15dfd 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -635,10 +635,13 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
       return success();
     }
 
-    // Do the trailing dim for source and destination match? If yes, then the
-    // corresponding index must be 0.
-    // FIXME: There's no way to tell for dynamic shapes, so we should bail out.
-    // However, that makes some tests fail, so we need to audit first.
+    // Do the trailing dim for source and destination match? If yes, and if the
+    // access indices are not all constant, then assume the last index is 0
+    // (byte-aligned). Note: for constant indices, the intraDataOffset computed
+    // below will give the exact value, so the trailingDimsMatch shortcut is
+    // not used in that case.
+    // FIXME: For dynamic indices where trailingDimsMatch, the assumption that
+    // the last index is 0 (byte-aligned) may be incorrect. See issue #131528.
     auto trailingDim = op.getBase().getType().getShape().back();
     bool trailingDimsMatch =
         ShapedType::isDynamic(trailingDim) || trailingDim == origElements;
@@ -646,8 +649,6 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
     auto stridedMetadata =
         memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getBase());
 
-    // FIXME: ATM, we do not test cases where offsets, sizes, or strides are
-    // non-zero. As such, this is not needed.
     OpFoldResult linearizedIndices;
     memref::LinearizedMemRefInfo linearizedInfo;
     std::tie(linearizedInfo, linearizedIndices) =
@@ -658,10 +659,17 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
             stridedMetadata.getConstifiedMixedStrides(),
             getAsOpFoldResult(adaptor.getIndices()));
 
+    // Prefer the exact intraDataOffset when it can be folded (e.g. all-constant
+    // indices). Otherwise fall back to 0, but only when isDivisibleInSize and
+    // trailingDimsMatch both hold; this assumes a dynamic last index is
+    // byte-aligned (the caller is responsible for passing a valid, aligned
+    // index; see the FIXME above). If neither condition holds, bail out.
     std::optional<int64_t> foldedNumFrontPadElems =
-        (isDivisibleInSize && trailingDimsMatch)
-            ? 0
-            : getConstantIntValue(linearizedInfo.intraDataOffset);
+        getConstantIntValue(linearizedInfo.intraDataOffset);
+    if (!foldedNumFrontPadElems) {
+      if (isDivisibleInSize && trailingDimsMatch)
+        foldedNumFrontPadElems = 0;
+    }
 
     if (!foldedNumFrontPadElems) {
       return rewriter.notifyMatchFailure(
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
index 21f073efc49b2..bec7736b90973 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
@@ -562,3 +562,27 @@ func.func @vector_store_i2_const_index_one_partial_store(%arg0: vector<1xi2>) {
 // CHECK: %[[BITCAST2:.+]] = vector.bitcast %[[SELECT]] : vector<4xi2> to vector<1xi8>
 // CHECK: %[[EXTRACT2:.+]] = vector.extract %[[BITCAST2]][0] : i8 from vector<1xi8>
 // CHECK: memref.atomic_yield %[[EXTRACT2]] : i8
+
+// -----
+
+// Regression test for https://github.com/llvm/llvm-project/issues/131528.
+// A vector.store with a non-zero constant column index on a 2D memref where the
+// trailing dimension matches the vector size must NOT be treated as
+// byte-aligned. Instead it must emit two partial (RMW) stores since the
+// 4-element i2 vector starting at column 1 crosses a byte boundary.
+func.func @vector_store_i2_2d_const_nonzero_col(%arg0: vector<4xi2>) {
+  %src = memref.alloc() : memref<3x4xi2>
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  vector.store %arg0, %src[%c0, %c1] : memref<3x4xi2>, vector<4xi2>
+  return
+}
+
+// CHECK-LABEL: func @vector_store_i2_2d_const_nonzero_col(
+// CHECK-SAME: %[[ARG0:.+]]: vector<4xi2>)
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// Emits two partial atomic RMWs: one for byte 0 (elements at positions [1..3])
+// and one for byte 1 (element at position [0]).
+// CHECK: memref.generic_atomic_rmw %[[ALLOC]][%[[C0]]]
+// CHECK: memref.generic_atomic_rmw %[[ALLOC]][{{.+}}]

``````````

</details>


https://github.com/llvm/llvm-project/pull/189235

