[Mlir-commits] [mlir] [mlir][vector] Fix emulation of "narrow" type `vector.store` (PR #133231)
Andrzej Warzyński
llvmlistbot at llvm.org
Thu Mar 27 05:28:03 PDT 2025
https://github.com/banach-space updated https://github.com/llvm/llvm-project/pull/133231
>From c210c2db93a5730b8d7c8d394f4445e6f0ca1555 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Thu, 27 Mar 2025 10:36:35 +0000
Subject: [PATCH] [mlir][vector] Fix emulation of "narrow" type `vector.store`
Below are two examples of "narrow" `vector.store` ops. The first example
does not require partial stores and hence no RMW stores; it is
currently emulated correctly.
```
func.func @example_1(%arg0: vector<4xi2>) {
%0 = memref.alloc() : memref<13xi2>
%c4 = arith.constant 4 : index
vector.store %arg0, %0[%c4] : memref<13xi2>, vector<4xi2>
return
}
```
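For reference, here is a sketch of the expected lowering for this case (condensed from the emulation's "full byte" path and the new test below; SSA value names are illustrative):
```
func.func @example_1(%arg0: vector<4xi2>) {
  // 13 x i2 = 26 bits, which rounds up to 4 container bytes.
  %alloc = memref.alloc() : memref<4xi8>
  // Offset 4 (in i2 elements) is byte-aligned: 4 * 2 bits = bit 8 = byte 1.
  %c1 = arith.constant 1 : index
  // The four i2 values fill exactly one byte, so a plain full store suffices.
  %0 = vector.bitcast %arg0 : vector<4xi2> to vector<1xi8>
  vector.store %0, %alloc[%c1] : memref<4xi8>, vector<1xi8>
  return
}
```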
The second example below does require a partial store (due to the
offset) and hence an RMW store.
```
func.func @example_2(%arg0: vector<4xi2>) {
%0 = memref.alloc() : memref<13xi2>
%c3 = arith.constant 3 : index
vector.store %arg0, %0[%c3] : memref<13xi2>, vector<4xi2>
return
}
```
At present, this is incorrectly emulated as a single "full" store (note
that the store offset is also incorrect):
```
func.func @example_2(%arg0: vector<4xi2>) {
%alloc = memref.alloc() : memref<4xi8>
%0 = vector.bitcast %arg0 : vector<4xi2> to vector<1xi8>
%c0 = arith.constant 0 : index
vector.store %0, %alloc[%c0] : memref<4xi8>, vector<1xi8>
return
}
```
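For comparison, here is a sketch of the expected emulation: the offset of 3 i2 elements (6 bits) falls mid-byte, so byte 0 and byte 1 are each only partially overwritten and two RMW stores are needed. The sketch is condensed from the new test added below; SSA value names are illustrative and the second, analogous RMW is elided.
```
func.func @example_2(%arg0: vector<4xi2>) {
  %alloc = memref.alloc() : memref<4xi8>
  %c0 = arith.constant 0 : index
  // First partial store: only the last i2 lane of byte 0 is overwritten.
  %mask0 = arith.constant dense<[false, false, false, true]> : vector<4xi1>
  %init = arith.constant dense<0> : vector<4xi2>
  %src0 = vector.extract_strided_slice %arg0 {offsets = [0], sizes = [1], strides = [1]} : vector<4xi2> to vector<1xi2>
  %val0 = vector.insert_strided_slice %src0, %init {offsets = [3], strides = [1]} : vector<1xi2> into vector<4xi2>
  %ret0 = memref.generic_atomic_rmw %alloc[%c0] : memref<4xi8> {
  ^bb0(%old: i8):
    %old_vec = vector.from_elements %old : vector<1xi8>
    %old_i2 = vector.bitcast %old_vec : vector<1xi8> to vector<4xi2>
    %sel = arith.select %mask0, %val0, %old_i2 : vector<4xi1>, vector<4xi2>
    %new_vec = vector.bitcast %sel : vector<4xi2> to vector<1xi8>
    %new = vector.extract %new_vec[0] : i8 from vector<1xi8>
    memref.atomic_yield %new : i8
  }
  // Second partial store (analogous RMW, not shown): the remaining three i2
  // lanes are written into lanes 0-2 of byte 1.
  return
}
```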
This PR fixes the issue. Additional comments are added to clarify the
existing logic.
---
.../Transforms/VectorEmulateNarrowType.cpp | 45 +++++++++---
.../vector-emulate-narrow-type-unaligned.mlir | 69 +++++++++++++++++++
2 files changed, 103 insertions(+), 11 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 5debebd3218ed..0457ae29f1c34 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -591,12 +591,12 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
// vector<4xi8>
auto origElements = valueToStore.getType().getNumElements();
- // Note, per-element-alignment was already verified above.
- bool isDivisibleInSize = origElements % emulatedPerContainerElem == 0;
auto stridedMetadata =
rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
+ // FIXME: ATM, we do not test cases where offsets, sizes, or strides are
+ // non-zero. As such, this is not needed.
OpFoldResult linearizedIndices;
memref::LinearizedMemRefInfo linearizedInfo;
std::tie(linearizedInfo, linearizedIndices) =
@@ -608,8 +608,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
getAsOpFoldResult(adaptor.getIndices()));
std::optional<int64_t> foldedNumFrontPadElems =
- isDivisibleInSize ? 0
- : getConstantIntValue(linearizedInfo.intraDataOffset);
+ getConstantIntValue(linearizedInfo.intraDataOffset);
if (!foldedNumFrontPadElems) {
return rewriter.notifyMatchFailure(
@@ -619,15 +618,39 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
auto memrefBase = cast<MemRefValue>(adaptor.getBase());
- // Conditions when atomic RMWs are not needed:
+ // RMWs are not needed when:
+ // * no _partial_ stores are required.
+ // A partial store is defined as a store in which only a part of the
+ // container element is overwritten, e.g.
+ //
+ // Dest before (8 bits)
+ // +----------+
+ // | 11000000 |
+ // +----------+
+ //
+ // Dest after storing 0xF at offset 4 (in bits)
+ // +----------+
+ // | 11001111 |
+ // +----------+
+ //
+ // At a higher level, this translates to:
// 1. The source vector size (in bits) is a multiple of byte size.
- // 2. The address of the store is aligned to the emulated width boundary.
+ // 2. The address of the store is aligned to the container type width
+ // boundary.
+ //
+ // EXAMPLE 1:
+ // Requires partial store:
+ // vector.store %arg0, %0[%c3] : memref<13xi2>, vector<4xi2>
//
- // For example, to store a vector<4xi2> to <13xi2> at offset 4, does not
- // need unaligned emulation because the store address is aligned and the
- // source is a whole byte.
- bool emulationRequiresPartialStores =
- !isDivisibleInSize || *foldedNumFrontPadElems != 0;
+ // EXAMPLE 2:
+ // Does not require a partial store:
+ // vector.store %arg0, %0[%c4] : memref<13xi2>, vector<4xi2>
+ //
+ // TODO: Take linearizedInfo.linearizedOffset into account. This is
+ // currently not needed/used/exercised as all our tests set offset to 0.
+ bool emulationRequiresPartialStores = *foldedNumFrontPadElems != 0;
+
if (!emulationRequiresPartialStores) {
// Basic case: storing full bytes.
auto numElements = origElements / emulatedPerContainerElem;
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
index 6fc974200c6f3..40961349a4a62 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
@@ -361,6 +361,75 @@ func.func @vector_maskedload_i2_constant_mask_unaligned(%passthru: vector<5xi2>)
/// vector.store
///----------------------------------------------------------------------------------------
+// -----
+
+// Most basic example to demonstrate where partial stores are not needed.
+
+func.func @vector_store_i2_const_index_no_partial_store(%arg0: vector<4xi2>) {
+ %0 = memref.alloc() : memref<13xi2>
+ %c4 = arith.constant 4 : index
+ vector.store %arg0, %0[%c4] : memref<13xi2>, vector<4xi2>
+ return
+}
+// CHECK-LABEL: func.func @vector_store_i2_const_index_no_partial_store(
+// CHECK-SAME: %[[ARG_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: vector<4xi2>) {
+// CHECK-NOT: memref.generic_atomic_rmw
+// CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<4xi8>
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[UPCAST:.*]] = vector.bitcast %[[ARG_0]] : vector<4xi2> to vector<1xi8>
+// CHECK: vector.store %[[UPCAST]], %[[ALLOC]]{{\[}}%[[C1]]] : memref<4xi8>, vector<1xi8>
+
+// -----
+
+// Small modification of the example above to demonstrate where partial stores
+// are needed.
+
+func.func @vector_store_i2_const_index_two_partial_stores(%arg0: vector<4xi2>) {
+ %0 = memref.alloc() : memref<13xi2>
+ %c3 = arith.constant 3 : index
+ vector.store %arg0, %0[%c3] : memref<13xi2>, vector<4xi2>
+ return
+}
+
+// CHECK-LABEL: func.func @vector_store_i2_const_index_two_partial_stores(
+// CHECK-SAME: %[[ARG_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: vector<4xi2>) {
+// CHECK: %[[VAL_1:.*]] = memref.alloc() : memref<4xi8>
+
+// First atomic RMW:
+// CHECK: %[[IDX_1:.*]] = arith.constant 0 : index
+// CHECK: %[[MASK_1:.*]] = arith.constant dense<[false, false, false, true]> : vector<4xi1>
+// CHECK: %[[INIT:.*]] = arith.constant dense<0> : vector<4xi2>
+// CHECK: %[[SLICE_1:.*]] = vector.extract_strided_slice %[[ARG_0]] {offsets = [0], sizes = [1], strides = [1]} : vector<4xi2> to vector<1xi2>
+// CHECK: %[[V1:.*]] = vector.insert_strided_slice %[[SLICE_1]], %[[INIT]] {offsets = [3], strides = [1]} : vector<1xi2> into vector<4xi2>
+// CHECK: memref.generic_atomic_rmw %[[VAL_1]]{{\[}}%[[IDX_1]]] : memref<4xi8> {
+// CHECK: ^bb0(%[[VAL_8:.*]]: i8):
+// CHECK: %[[VAL_9:.*]] = vector.from_elements %[[VAL_8]] : vector<1xi8>
+// CHECK: %[[DOWNCAST_1:.*]] = vector.bitcast %[[VAL_9]] : vector<1xi8> to vector<4xi2>
+// CHECK: %[[SELECT_1:.*]] = arith.select %[[MASK_1]], %[[V1]], %[[DOWNCAST_1]] : vector<4xi1>, vector<4xi2>
+// CHECK: %[[UPCAST_1:.*]] = vector.bitcast %[[SELECT_1]] : vector<4xi2> to vector<1xi8>
+// CHECK: %[[RES_1:.*]] = vector.extract %[[UPCAST_1]][0] : i8 from vector<1xi8>
+// CHECK: memref.atomic_yield %[[RES_1]] : i8
+// CHECK: }
+
+// Second atomic RMW:
+// CHECK: %[[VAL_14:.*]] = arith.constant 1 : index
+// CHECK: %[[IDX_2:.*]] = arith.addi %[[IDX_1]], %[[VAL_14]] : index
+// CHECK: %[[VAL_16:.*]] = vector.extract_strided_slice %[[ARG_0]] {offsets = [1], sizes = [3], strides = [1]} : vector<4xi2> to vector<3xi2>
+// CHECK: %[[V2:.*]] = vector.insert_strided_slice %[[VAL_16]], %[[INIT]] {offsets = [0], strides = [1]} : vector<3xi2> into vector<4xi2>
+// CHECK: %[[MASK_2:.*]] = arith.constant dense<[true, true, true, false]> : vector<4xi1>
+// CHECK: memref.generic_atomic_rmw %[[VAL_1]]{{\[}}%[[IDX_2]]] : memref<4xi8> {
+// CHECK: ^bb0(%[[VAL_20:.*]]: i8):
+// CHECK: %[[VAL_21:.*]] = vector.from_elements %[[VAL_20]] : vector<1xi8>
+// CHECK: %[[DOWNCAST_2:.*]] = vector.bitcast %[[VAL_21]] : vector<1xi8> to vector<4xi2>
+// CHECK: %[[SELECT_2:.*]] = arith.select %[[MASK_2]], %[[V2]], %[[DOWNCAST_2]] : vector<4xi1>, vector<4xi2>
+// CHECK: %[[UPCAST_2:.*]] = vector.bitcast %[[SELECT_2]] : vector<4xi2> to vector<1xi8>
+// CHECK: %[[RES_2:.*]] = vector.extract %[[UPCAST_2]][0] : i8 from vector<1xi8>
+// CHECK: memref.atomic_yield %[[RES_2]] : i8
+// CHECK: }
+
+// -----
+
func.func @vector_store_i2_const_index_two_partial_stores(%arg0: vector<3xi2>) {
%src = memref.alloc() : memref<3x3xi2>
%c0 = arith.constant 0 : index