[Mlir-commits] [mlir] [mlir][vector] Add assumeAligned mode to vector.store narrow type emulation (PR #178565)

Han-Chung Wang llvmlistbot at llvm.org
Thu Jan 29 10:13:14 PST 2026


https://github.com/hanhanW updated https://github.com/llvm/llvm-project/pull/178565

>From 95859e5e4cbf4abed8d0c210d2df43cf065b29ea Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Wed, 28 Jan 2026 14:05:29 -0800
Subject: [PATCH 1/2] [mlir][vector] Add assumeAligned mode to vector.store
 narrow type emulation

This revision adds a new `assumeAligned` mode to the vector.store
emulation, so downstream projects can take a simpler lowering path when
their stores meet the required constraints. E.g., if the store offset is
always aligned to the container element type, the check for front padding
can be skipped and the store is emitted as a plain bitcast + store.

Signed-off-by: hanhanW <hanhan0912 at gmail.com>
---
 .../Vector/Transforms/VectorRewritePatterns.h |  9 ++-
 .../Transforms/VectorEmulateNarrowType.cpp    | 49 +++++++++++++++--
 ...mulate-narrow-type-aligned-store-only.mlir | 55 +++++++++++++++++++
 .../Dialect/MemRef/TestEmulateNarrowType.cpp  | 10 +++-
 4 files changed, 115 insertions(+), 8 deletions(-)
 create mode 100644 mlir/test/Dialect/Vector/vector-emulate-narrow-type-aligned-store-only.mlir

diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 69438011d2287..bf0f450098a87 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -388,10 +388,15 @@ void populateVectorMaskMaterializationPatterns(RewritePatternSet &patterns,
 /// Appends patterns for emulating vector operations over narrow types with ops
 /// over wider types. The `disableAtomicRMW` indicates whether to use a normal
 /// read-modify-write sequence instead of using `memref.generic_atomic_rmw` to
-/// perform subbyte storing.
+/// perform subbyte storing. When `assumeAligned` is true, store offsets are
+/// assumed to be aligned to container element boundaries, so a store whose
+/// source vector fills whole container elements is emitted as a simple
+/// bitcast + store without checking the offset. Stores whose source vector
+/// does not fill whole container elements are rejected.
 void populateVectorNarrowTypeEmulationPatterns(
     const arith::NarrowTypeEmulationConverter &typeConverter,
-    RewritePatternSet &patterns, bool disableAtomicRMW = false);
+    RewritePatternSet &patterns, bool disableAtomicRMW = false,
+    bool assumeAligned = false);
 
 /// Populates patterns for both MeMref flattening and Vector narrow type
 /// emulation.
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 3a3231d513369..8bcc503262630 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -510,6 +510,13 @@ namespace {
 
 // Emulate `vector.store` using a multi-byte container type.
 //
+// When `assumeAligned` is true, store offsets are assumed to be aligned to
+// container element boundaries, so a store whose source vector fills whole
+// container elements (isDivisibleInSize) is emitted as a simple bitcast +
+// store without checking the offset. Stores that are not divisible in size
+// are rejected. This is useful for downstream users that have already
+// ensured alignment.
+//
 // The container type is obtained through Op adaptor and would normally be
 // generated via `NarrowTypeEmulationConverter`.
 //
@@ -550,9 +557,10 @@ namespace {
 struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
   using Base::Base;
 
-  ConvertVectorStore(MLIRContext *context, bool disableAtomicRMW)
+  ConvertVectorStore(MLIRContext *context, bool disableAtomicRMW,
+                     bool assumeAligned)
       : OpConversionPattern<vector::StoreOp>(context),
-        disableAtomicRMW(disableAtomicRMW) {}
+        disableAtomicRMW(disableAtomicRMW), assumeAligned(assumeAligned) {}
 
   LogicalResult
   matchAndRewrite(vector::StoreOp op, OpAdaptor adaptor,
@@ -595,6 +603,37 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
     auto origElements = valueToStore.getType().getNumElements();
     // Note, per-element-alignment was already verified above.
     bool isDivisibleInSize = origElements % emulatedPerContainerElem == 0;
+
+    // In assume-aligned mode, isDivisibleInSize alone is sufficient — the
+    // caller guarantees that store offsets are aligned to container element
+    // boundaries.
+    if (assumeAligned) {
+      if (!isDivisibleInSize)
+        return rewriter.notifyMatchFailure(
+            op, "the source vector does not fill whole container elements "
+                "(not divisible in size)");
+
+      auto stridedMetadata = memref::ExtractStridedMetadataOp::create(
+          rewriter, loc, op.getBase());
+      OpFoldResult linearizedIndices;
+      std::tie(std::ignore, linearizedIndices) =
+          memref::getLinearizedMemRefOffsetAndSize(
+              rewriter, loc, emulatedBits, containerBits,
+              stridedMetadata.getConstifiedMixedOffset(),
+              stridedMetadata.getConstifiedMixedSizes(),
+              stridedMetadata.getConstifiedMixedStrides(),
+              getAsOpFoldResult(adaptor.getIndices()));
+      auto memrefBase = cast<MemRefValue>(adaptor.getBase());
+      int numElements = origElements / emulatedPerContainerElem;
+      auto bitCast = vector::BitCastOp::create(
+          rewriter, loc, VectorType::get(numElements, containerElemTy),
+          op.getValueToStore());
+      rewriter.replaceOpWithNewOp<vector::StoreOp>(
+          op, bitCast.getResult(), memrefBase,
+          getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
+      return success();
+    }
+
     // Do the trailing dim for source and destination match? If yes, then the
     // corresponding index must be 0.
     // FIXME: There's no way to tell for dynamic shapes, so we should bail out.
@@ -812,6 +851,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
 
 private:
   const bool disableAtomicRMW;
+  const bool assumeAligned;
 };
 
 //===----------------------------------------------------------------------===//
@@ -2244,7 +2284,7 @@ struct RewriteVectorTranspose : OpRewritePattern<vector::TransposeOp> {
 // The emulated type is inferred from the converted memref type.
 void vector::populateVectorNarrowTypeEmulationPatterns(
     const arith::NarrowTypeEmulationConverter &typeConverter,
-    RewritePatternSet &patterns, bool disableAtomicRMW) {
+    RewritePatternSet &patterns, bool disableAtomicRMW, bool assumeAligned) {
   // Populate `vector.*` conversion patterns.
   // TODO: #119553 support atomicity
   patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad,
@@ -2254,7 +2294,8 @@ void vector::populateVectorNarrowTypeEmulationPatterns(
   // Populate `vector.*` store conversion patterns. The caller can choose
   // to avoid emitting atomic operations and reduce it to read-modify-write
   // sequence for stores if it is known there are no thread contentions.
-  patterns.insert<ConvertVectorStore>(patterns.getContext(), disableAtomicRMW);
+  patterns.insert<ConvertVectorStore>(patterns.getContext(), disableAtomicRMW,
+                                      assumeAligned);
 }
 
 void vector::populateVectorNarrowTypeRewritePatterns(
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-aligned-store-only.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-aligned-store-only.mlir
new file mode 100644
index 0000000000000..2ef568fa6a741
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-aligned-store-only.mlir
@@ -0,0 +1,55 @@
+// RUN: mlir-opt --test-emulate-narrow-int="memref-load-bitwidth=8 assume-aligned=true" --cse --verify-diagnostics --split-input-file %s | FileCheck %s
+
+/// Aligned store into a statically shaped memref - the source vector fills
+/// whole container elements. Produces a simple bitcast + store.
+func.func @vector_store_i4_aligned_const(%arg0: vector<8xi4>, %arg1: index, %arg2: index) {
+    %0 = memref.alloc() : memref<4x8xi4>
+    vector.store %arg0, %0[%arg1, %arg2] : memref<4x8xi4>, vector<8xi4>
+    return
+}
+//  CHECK-DAG: #[[$MAP:.+]] = affine_map<()[s0, s1] -> (s0 * 4 + s1 floordiv 2)>
+//      CHECK: func @vector_store_i4_aligned_const
+// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]: vector<8xi4>
+// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9]+]]: index
+// CHECK-SAME:   %[[ARG2:[a-zA-Z0-9]+]]: index
+//      CHECK:   %[[ALLOC:.+]] = memref.alloc() : memref<16xi8>
+//      CHECK:   %[[INDEX:.+]] = affine.apply #[[$MAP]]()[%[[ARG1]], %[[ARG2]]]
+//      CHECK:   %[[VEC_I8:.+]] = vector.bitcast %[[ARG0]] : vector<8xi4> to vector<4xi8>
+//      CHECK:   vector.store %[[VEC_I8]], %[[ALLOC]][%[[INDEX]]] : memref<16xi8>, vector<4xi8>
+
+// -----
+
+/// Aligned store into a dynamically shaped memref. The source vector (8 x i4
+/// = 32 bits) is a whole multiple of the container element size (i8 = 8 bits),
+/// so no partial stores are needed, given the assumed-aligned store offset.
+func.func @vector_store_i4_aligned_dynamic(%arg0: vector<8xi4>, %arg1: index, %arg2: index, %arg3: index, %arg4: index) {
+    %0 = memref.alloc(%arg1, %arg2) : memref<?x?xi4>
+    vector.store %arg0, %0[%arg3, %arg4] : memref<?x?xi4>, vector<8xi4>
+    return
+}
+//  CHECK-DAG: #[[$MAP:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 2, s0 floordiv 2)>
+//  CHECK-DAG: #[[$MAP1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 + s0 * s1) floordiv 2)>
+//      CHECK: func @vector_store_i4_aligned_dynamic
+// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]: vector<8xi4>
+// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9]+]]: index
+// CHECK-SAME:   %[[ARG2:[a-zA-Z0-9]+]]: index
+// CHECK-SAME:   %[[ARG3:[a-zA-Z0-9]+]]: index
+// CHECK-SAME:   %[[ARG4:[a-zA-Z0-9]+]]: index
+//      CHECK:   %[[SIZE:.+]] = affine.max #[[$MAP]]()[%[[ARG2]], %[[ARG1]]]
+//      CHECK:   %[[ALLOC:.+]] = memref.alloc(%[[SIZE]]) : memref<?xi8>
+//      CHECK:   %[[INDEX:.+]] = affine.apply #[[$MAP1]]()[%[[ARG3]], %[[ARG2]], %[[ARG4]]]
+//      CHECK:   %[[VEC_I8:.+]] = vector.bitcast %[[ARG0]] : vector<8xi4> to vector<4xi8>
+//      CHECK:   vector.store %[[VEC_I8]], %[[ALLOC]][%[[INDEX]]] : memref<?xi8>, vector<4xi8>
+
+// -----
+
+/// The source vector does not fill whole container elements (3 x i4 != N x i8),
+/// so the aligned path rejects it. With assume-aligned=true, the unaligned
+/// emulation path is not attempted, so legalization fails.
+func.func @vector_store_i4_not_divisible(%arg0: vector<3xi4>) {
+    %0 = memref.alloc() : memref<12xi4>
+    %c0 = arith.constant 0 : index
+    // expected-error @below {{failed to legalize operation 'vector.store' that was explicitly marked illegal}}
+    vector.store %arg0, %0[%c0] : memref<12xi4>, vector<3xi4>
+    return
+}
diff --git a/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp b/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp
index b5f015aff19b4..921afb8d2180a 100644
--- a/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp
+++ b/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp
@@ -100,8 +100,8 @@ struct TestEmulateNarrowTypePass
 
     arith::populateArithNarrowTypeEmulationPatterns(typeConverter, patterns);
     memref::populateMemRefNarrowTypeEmulationPatterns(typeConverter, patterns);
-    vector::populateVectorNarrowTypeEmulationPatterns(typeConverter, patterns,
-                                                      disableAtomicRMW);
+    vector::populateVectorNarrowTypeEmulationPatterns(
+        typeConverter, patterns, disableAtomicRMW, assumeAligned);
 
     if (failed(applyPartialConversion(op, target, std::move(patterns))))
       signalPassFailure();
@@ -126,6 +126,12 @@ struct TestEmulateNarrowTypePass
       llvm::cl::desc("disable atomic read-modify-write and prefer generating "
                      "normal sequence"),
       llvm::cl::init(false)};
+
+  Option<bool> assumeAligned{
+      *this, "assume-aligned",
+      llvm::cl::desc("assume store offsets are aligned to container element "
+                     "boundaries"),
+      llvm::cl::init(false)};
 };
 
 struct TestMemRefFlattenAndVectorNarrowTypeEmulationPass

>From ad23cdf148fc2129981836749b3ee7ea4d5ffaf7 Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Thu, 29 Jan 2026 10:12:57 -0800
Subject: [PATCH 2/2] clang-format

Signed-off-by: hanhanW <hanhan0912 at gmail.com>
---
 .../lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 8bcc503262630..7fd639f1354ea 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -613,8 +613,8 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
             op, "the source vector does not fill whole container elements "
                 "(not divisible in size)");
 
-      auto stridedMetadata = memref::ExtractStridedMetadataOp::create(
-          rewriter, loc, op.getBase());
+      auto stridedMetadata =
+          memref::ExtractStridedMetadataOp::create(rewriter, loc, op.getBase());
       OpFoldResult linearizedIndices;
       std::tie(std::ignore, linearizedIndices) =
           memref::getLinearizedMemRefOffsetAndSize(



More information about the Mlir-commits mailing list