[Mlir-commits] [mlir] [mlir][vector] Fix incorrect byte-alignment assumption in ConvertVectorStore (PR #189235)
llvmlistbot at llvm.org
Sun Mar 29 06:02:23 PDT 2026
llvmbot wrote:
@llvm/pr-subscribers-mlir-vector
Author: Mehdi Amini (joker-eph)
<details>
<summary>Changes</summary>
When `ConvertVectorStore` emits the narrow-type emulation for a `vector.store` into a 2-D memref, it previously assumed that if the vector fills a whole number of bytes (`isDivisibleInSize`) and the trailing dimension of the memref exactly matches the vector size (`trailingDimsMatch`), then the last-dimension index must be zero and no sub-byte alignment adjustment is needed. This assumption is wrong: a valid store such as
vector.store %v, %src[%c0, %c1] : memref<3x4xi2>, vector<4xi2>
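has a non-zero column index (`%c1` == 1) even though the trailing dimension (4) equals the vector size (4). The incorrect shortcut sent the pattern down the "aligned" path, which emits a plain bitcast + store at byte offset 0. That writes the four elements to linear positions 0-3 instead of 1-4: element `[0, 0]`, which the store must not touch, is overwritten, every stored value lands one position too early, and position 4 (the first element of the second byte) is never written.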
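To make the failure concrete, the C++ sketch below models the emulated buffer as a plain byte array and compares the two lowerings for the store above. It is a hypothetical illustration, not code from the pattern: the helper names (`storeExact`, `storeAssumingAligned`, `getElem`, `setElem`) are invented here, and it assumes element k of a byte occupies bits [2k, 2k+1], least-significant first. The bit order changes the resulting byte values but not which element positions are touched.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// Hypothetical model: memref<3x4xi2> emulated as memref<3xi8>, i.e. twelve
// 2-bit elements packed four per byte. Element k of a byte is assumed to
// occupy bits [2k, 2k+1] (least-significant first).
constexpr int kBitsPerElem = 2;
constexpr int kElemsPerByte = 8 / kBitsPerElem;  // 4
constexpr int kRows = 3;
constexpr int kCols = 4;

using Buffer = std::array<std::uint8_t, kRows * kCols / kElemsPerByte>;
using Vec4xi2 = std::array<std::uint8_t, 4>;

// Read the 2-bit element at a row-major (linearized) position.
std::uint8_t getElem(const Buffer &buf, int linear) {
  assert(linear >= 0 && linear < kRows * kCols);
  int shift = (linear % kElemsPerByte) * kBitsPerElem;
  return static_cast<std::uint8_t>((buf[linear / kElemsPerByte] >> shift) & 0x3);
}

// Overwrite only the bits of one element (a read-modify-write of its byte).
void setElem(Buffer &buf, int linear, std::uint8_t value) {
  assert(linear >= 0 && linear < kRows * kCols);
  int shift = (linear % kElemsPerByte) * kBitsPerElem;
  std::uint8_t &byte = buf[linear / kElemsPerByte];
  byte = static_cast<std::uint8_t>((byte & ~(0x3 << shift)) |
                                   ((value & 0x3) << shift));
}

// Intended semantics of vector.store %v, %src[row, col]: element i lands at
// linear position row * kCols + col + i, even across a byte boundary.
void storeExact(Buffer &buf, int row, int col, const Vec4xi2 &v) {
  int linear = row * kCols + col;
  for (int i = 0; i < 4; ++i)
    setElem(buf, linear + i, v[i]);
}

// Pre-fix "aligned" path: the vector is packed into one byte and written
// whole at byte index linear / kElemsPerByte, i.e. byte 0 for [0, 1].
void storeAssumingAligned(Buffer &buf, int row, int col, const Vec4xi2 &v) {
  int linear = row * kCols + col;
  std::uint8_t packed = 0;
  for (int i = 0; i < 4; ++i)
    packed = static_cast<std::uint8_t>(packed | ((v[i] & 0x3) << (i * kBitsPerElem)));
  buf[linear / kElemsPerByte] = packed;
}
```

Storing `{1, 2, 3, 1}` at `[0, 1]` into a zeroed buffer with `storeExact` leaves position 0 at 0 and puts the last value at position 4, in the second byte. `storeAssumingAligned` instead writes 1 at position 0 and leaves position 4 untouched: position 0 is clobbered and the value meant for position 4 lands at position 3.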
Fix: prefer `linearizedInfo.intraDataOffset` whenever it folds to a constant (as it does for constant-index stores like the one above); it gives the exact element offset within the containing byte. Fall back to the old `0` assumption only when `intraDataOffset` cannot be folded to a constant (e.g. because an index is dynamic) **and** `isDivisibleInSize && trailingDimsMatch` still holds; otherwise the pattern bails out as before. This preserves the existing behaviour for dynamic-index stores while fixing the constant-index case.
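The selection order described in the fix can be summarised as a small pure function. This is a hypothetical sketch that mirrors the diff rather than calling MLIR APIs: `foldedIntraDataOffset` stands in for the result of `getConstantIntValue(linearizedInfo.intraDataOffset)`, and a `std::nullopt` result from the new helper corresponds to the `notifyMatchFailure` bail-out. Both function names are invented for illustration.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>

// New logic (hypothetical model): an exact, folded offset always wins; the
// byte-aligned assumption is only a fallback for non-constant offsets.
std::optional<std::int64_t> chooseFrontPadElems(
    std::optional<std::int64_t> foldedIntraDataOffset, bool isDivisibleInSize,
    bool trailingDimsMatch) {
  if (foldedIntraDataOffset)
    return foldedIntraDataOffset;
  if (isDivisibleInSize && trailingDimsMatch)
    return 0;
  return std::nullopt;  // caller bails out (notifyMatchFailure)
}

// Pre-fix logic (hypothetical model): the shortcut ignored the folded offset.
std::optional<std::int64_t> chooseFrontPadElemsOld(
    std::optional<std::int64_t> foldedIntraDataOffset, bool isDivisibleInSize,
    bool trailingDimsMatch) {
  if (isDivisibleInSize && trailingDimsMatch)
    return 0;
  return foldedIntraDataOffset;
}
```

For the regression store, `vector<4xi2>` fills exactly one byte (so `isDivisibleInSize` holds) and the trailing dimension is 4 (so `trailingDimsMatch` holds), while the folded offset is 1: the old logic yields 0 and the new logic yields 1.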
Fixes #131528
Assisted-by: Claude Code
---
Full diff: https://github.com/llvm/llvm-project/pull/189235.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp (+17-9)
- (modified) mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir (+24)
``````````diff
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 583cda7ac2810..60161cdf15dfd 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -635,10 +635,13 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
       return success();
     }
 
-    // Do the trailing dim for source and destination match? If yes, then the
-    // corresponding index must be 0.
-    // FIXME: There's no way to tell for dynamic shapes, so we should bail out.
-    // However, that makes some tests fail, so we need to audit first.
+    // Do the trailing dim for source and destination match? If yes, and if the
+    // access indices are not all constant, then assume the last index is 0
+    // (byte-aligned). Note: for constant indices, the intraDataOffset computed
+    // below will give the exact value, so the trailingDimsMatch shortcut is
+    // not used in that case.
+    // FIXME: For dynamic indices where trailingDimsMatch, the assumption that
+    // the last index is 0 (byte-aligned) may be incorrect. See issue #131528.
     auto trailingDim = op.getBase().getType().getShape().back();
     bool trailingDimsMatch =
         ShapedType::isDynamic(trailingDim) || trailingDim == origElements;
@@ -646,8 +649,6 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
     auto stridedMetadata =
         memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getBase());
 
-    // FIXME: ATM, we do not test cases where offsets, sizes, or strides are
-    // non-zero. As such, this is not needed.
     OpFoldResult linearizedIndices;
     memref::LinearizedMemRefInfo linearizedInfo;
     std::tie(linearizedInfo, linearizedIndices) =
@@ -658,10 +659,17 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
             stridedMetadata.getConstifiedMixedStrides(),
             getAsOpFoldResult(adaptor.getIndices()));
 
+    // Prefer the exact intraDataOffset when it can be folded (e.g. all-constant
+    // indices). Fall back to 0 only when the trailing dimension exactly matches
+    // the vector size (trailingDimsMatch), because in that case a dynamic last
+    // index implies byte-alignment (the caller is responsible for passing a
+    // valid, aligned index). If neither condition holds, bail out.
     std::optional<int64_t> foldedNumFrontPadElems =
-        (isDivisibleInSize && trailingDimsMatch)
-            ? 0
-            : getConstantIntValue(linearizedInfo.intraDataOffset);
+        getConstantIntValue(linearizedInfo.intraDataOffset);
+    if (!foldedNumFrontPadElems) {
+      if (isDivisibleInSize && trailingDimsMatch)
+        foldedNumFrontPadElems = 0;
+    }
 
     if (!foldedNumFrontPadElems) {
       return rewriter.notifyMatchFailure(
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
index 21f073efc49b2..bec7736b90973 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
@@ -562,3 +562,27 @@ func.func @vector_store_i2_const_index_one_partial_store(%arg0: vector<1xi2>) {
 // CHECK: %[[BITCAST2:.+]] = vector.bitcast %[[SELECT]] : vector<4xi2> to vector<1xi8>
 // CHECK: %[[EXTRACT2:.+]] = vector.extract %[[BITCAST2]][0] : i8 from vector<1xi8>
 // CHECK: memref.atomic_yield %[[EXTRACT2]] : i8
+
+// -----
+
+// Regression test for https://github.com/llvm/llvm-project/issues/131528.
+// A vector.store with a non-zero constant column index on a 2D memref where the
+// trailing dimension matches the vector size must NOT be treated as
+// byte-aligned. Instead it must emit two partial (RMW) stores since the
+// 4-element i2 vector starting at column 1 crosses a byte boundary.
+func.func @vector_store_i2_2d_const_nonzero_col(%arg0: vector<4xi2>) {
+  %src = memref.alloc() : memref<3x4xi2>
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  vector.store %arg0, %src[%c0, %c1] : memref<3x4xi2>, vector<4xi2>
+  return
+}
+
+// CHECK-LABEL: func @vector_store_i2_2d_const_nonzero_col(
+// CHECK-SAME: %[[ARG0:.+]]: vector<4xi2>)
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// Emits two partial atomic RMWs: one for byte 0 (elements at positions [1..3])
+// and one for byte 1 (element at position [0]).
+// CHECK: memref.generic_atomic_rmw %[[ALLOC]][%[[C0]]]
+// CHECK: memref.generic_atomic_rmw %[[ALLOC]][{{.+}}]
``````````
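The two `memref.generic_atomic_rmw` ops expected by the new test follow from splitting a 4-element `i2` store at element offset 1 into a partial first byte, whole bytes, and a partial last byte. The sketch below, built around a hypothetical `planSubByteStore` helper, only reproduces that arithmetic under the assumption of 8-bit container elements; it is not the pattern's implementation.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Hypothetical summary of how a sub-byte vector store splits up.
struct StorePlan {
  std::int64_t frontPartialElems;  // elements written via RMW into the first byte
  std::int64_t fullBytes;          // bytes that can be written with plain stores
  std::int64_t tailPartialElems;   // elements written via RMW into the last byte
};

StorePlan planSubByteStore(std::int64_t numElems, std::int64_t elemsPerByte,
                           std::int64_t frontPadElems) {
  assert(elemsPerByte > 0 && frontPadElems >= 0 && frontPadElems < elemsPerByte);
  StorePlan plan{0, 0, 0};
  std::int64_t remaining = numElems;
  // A non-zero front padding means the first byte is only partially covered.
  if (frontPadElems != 0) {
    plan.frontPartialElems = std::min(elemsPerByte - frontPadElems, remaining);
    remaining -= plan.frontPartialElems;
  }
  plan.fullBytes = remaining / elemsPerByte;
  plan.tailPartialElems = remaining % elemsPerByte;
  return plan;
}

// Number of read-modify-write (partial) stores the plan needs.
int numPartialStores(const StorePlan &plan) {
  return (plan.frontPartialElems > 0 ? 1 : 0) + (plan.tailPartialElems > 0 ? 1 : 0);
}
```

With `numElems = 4`, `elemsPerByte = 4` and `frontPadElems = 1` it gives 3 front elements, 0 whole bytes and 1 trailing element, i.e. two partial (RMW) stores, matching the two `generic_atomic_rmw` CHECK lines. With `frontPadElems = 0` it gives a single whole byte and no partial stores, which is the aligned case the old shortcut assumed.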
</details>
https://github.com/llvm/llvm-project/pull/189235