[Mlir-commits] [mlir] ad948fa - [mlir][vector] Document `ConvertVectorStore` + unify var names (nfc) (#126422)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sat Feb 15 12:16:29 PST 2025
Author: Andrzej Warzyński
Date: 2025-02-15T20:16:25Z
New Revision: ad948fa028bdfe1f15785aec4477f92ec681637a
URL: https://github.com/llvm/llvm-project/commit/ad948fa028bdfe1f15785aec4477f92ec681637a
DIFF: https://github.com/llvm/llvm-project/commit/ad948fa028bdfe1f15785aec4477f92ec681637a.diff
LOG: [mlir][vector] Document `ConvertVectorStore` + unify var names (nfc) (#126422)
1. Documents `ConvertVectorStore`. As the generated output is rather complex, I
have refined the comments + variable names in:
* "vector-emulate-narrow-type-unaligned-non-atomic.mlir",
to serve as reference for this pattern.
2. As a follow-on for #123527, renames `isAlignedEmulation` to `isFullyAligned`
and `numSrcElemsPerDest` to `emulatedPerContainerElem`.
Added:
Modified:
mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned-non-atomic.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index bf1ecd7d4559c..5d8a525ac87f1 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -432,7 +432,45 @@ namespace {
// ConvertVectorStore
//===----------------------------------------------------------------------===//
-// TODO: Document-me
+// Emulate `vector.store` using a multi-byte container type.
+//
+// The container type is obtained through the Op adaptor and would normally be
+// generated via `NarrowTypeEmulationConverter`.
+//
+// EXAMPLE 1
+// (aligned store of i4, emulated using i8 as the container type)
+//
+// vector.store %src, %dest[%idx_1, %idx_2] : memref<4x8xi4>, vector<8xi4>
+//
+// is rewritten as:
+//
+// %src_bitcast = vector.bitcast %src : vector<8xi4> to vector<4xi8>
+// vector.store %src_bitcast, %dest_bitcast[%idx]
+// : memref<16xi8>, vector<4xi8>
+//
+// EXAMPLE 2
+// (unaligned store of i2, emulated using i8 as the container type)
+//
+// vector.store %src, %dest[%c2, %c0] :memref<3x3xi2>, vector<3xi2>
+//
+// The i2 store is emulated through 2 x RMW sequences. The destination i2 memref
+// is modelled using 3 bytes:
+//
+// Byte 0 Byte 1 Byte 2
+// +----------+----------+----------+
+// | oooooooo | ooooNNNN | NNoooooo |
+// +----------+----------+----------+
+//
+// N - (N)ew entries (i.e. to be overwritten by vector.store)
+// o - (o)ld entries (to be preserved)
+//
+// For the generated output in the non-atomic case, see:
+// * `@vector_store_i2_const_index_two_partial_stores`
+// in:
+// * "vector-emulate-narrow-type-unaligned-non-atomic.mlir".
+//
+// NOTE: By default, all RMW sequences are atomic. Set `disableAtomicRMW` to
+// `true` to generate non-atomic RMW sequences.
struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
using OpConversionPattern::OpConversionPattern;
@@ -464,7 +502,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
op, "impossible to pack emulated elements into container elements "
"(bit-wise misalignment)");
}
- int numSrcElemsPerDest = containerBits / emulatedBits;
+ int emulatedPerContainerElem = containerBits / emulatedBits;
// Adjust the number of elements to store when emulating narrow types.
// Here only the 1-D vector store is considered, and the N-D memref types
@@ -480,7 +518,8 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
// vector<4xi8>
auto origElements = valueToStore.getType().getNumElements();
- bool isAlignedEmulation = origElements % numSrcElemsPerDest == 0;
+ // Note, per-element-alignment was already verified above.
+ bool isFullyAligned = origElements % emulatedPerContainerElem == 0;
auto stridedMetadata =
rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
@@ -496,9 +535,8 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
getAsOpFoldResult(adaptor.getIndices()));
std::optional<int64_t> foldedNumFrontPadElems =
- isAlignedEmulation
- ? 0
- : getConstantIntValue(linearizedInfo.intraDataOffset);
+ isFullyAligned ? 0
+ : getConstantIntValue(linearizedInfo.intraDataOffset);
if (!foldedNumFrontPadElems) {
return rewriter.notifyMatchFailure(
@@ -516,10 +554,10 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
// need unaligned emulation because the store address is aligned and the
// source is a whole byte.
bool emulationRequiresPartialStores =
- !isAlignedEmulation || *foldedNumFrontPadElems != 0;
+ !isFullyAligned || *foldedNumFrontPadElems != 0;
if (!emulationRequiresPartialStores) {
// Basic case: storing full bytes.
- auto numElements = origElements / numSrcElemsPerDest;
+ auto numElements = origElements / emulatedPerContainerElem;
auto bitCast = rewriter.create<vector::BitCastOp>(
loc, VectorType::get(numElements, containerElemTy),
op.getValueToStore());
@@ -567,7 +605,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
// Build a mask used for rmw.
auto subWidthStoreMaskType =
- VectorType::get({numSrcElemsPerDest}, rewriter.getI1Type());
+ VectorType::get({emulatedPerContainerElem}, rewriter.getI1Type());
auto storeFunc = disableAtomicRMW ? nonAtomicRMW : atomicRMW;
@@ -576,10 +614,11 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
// with the unaligned part so that the rest elements are aligned to width
// boundary.
auto frontSubWidthStoreElem =
- (numSrcElemsPerDest - *foldedNumFrontPadElems) % numSrcElemsPerDest;
+ (emulatedPerContainerElem - *foldedNumFrontPadElems) %
+ emulatedPerContainerElem;
if (frontSubWidthStoreElem > 0) {
- SmallVector<bool> frontMaskValues(numSrcElemsPerDest, false);
- if (*foldedNumFrontPadElems + origElements < numSrcElemsPerDest) {
+ SmallVector<bool> frontMaskValues(emulatedPerContainerElem, false);
+ if (*foldedNumFrontPadElems + origElements < emulatedPerContainerElem) {
std::fill_n(frontMaskValues.begin() + *foldedNumFrontPadElems,
origElements, true);
frontSubWidthStoreElem = origElements;
@@ -590,7 +629,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
auto frontMask = rewriter.create<arith::ConstantOp>(
loc, DenseElementsAttr::get(subWidthStoreMaskType, frontMaskValues));
- currentSourceIndex = numSrcElemsPerDest - (*foldedNumFrontPadElems);
+ currentSourceIndex = emulatedPerContainerElem - (*foldedNumFrontPadElems);
auto value =
extractSliceIntoByte(rewriter, loc, valueToStore, 0,
frontSubWidthStoreElem, *foldedNumFrontPadElems);
@@ -614,8 +653,9 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
// After the previous step, the store address is aligned to the emulated
// width boundary.
int64_t fullWidthStoreSize =
- (origElements - currentSourceIndex) / numSrcElemsPerDest;
- int64_t numNonFullWidthElements = fullWidthStoreSize * numSrcElemsPerDest;
+ (origElements - currentSourceIndex) / emulatedPerContainerElem;
+ int64_t numNonFullWidthElements =
+ fullWidthStoreSize * emulatedPerContainerElem;
if (fullWidthStoreSize > 0) {
auto fullWidthStorePart = staticallyExtractSubvector(
rewriter, loc, valueToStore, currentSourceIndex,
@@ -624,7 +664,8 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
auto originType = cast<VectorType>(fullWidthStorePart.getType());
auto memrefElemType = getElementTypeOrSelf(memrefBase.getType());
auto storeType = VectorType::get(
- {originType.getNumElements() / numSrcElemsPerDest}, memrefElemType);
+ {originType.getNumElements() / emulatedPerContainerElem},
+ memrefElemType);
auto bitCast = rewriter.create<vector::BitCastOp>(loc, storeType,
fullWidthStorePart);
rewriter.create<vector::StoreOp>(loc, bitCast.getResult(), memrefBase,
@@ -646,7 +687,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
currentSourceIndex, remainingElements, 0);
// Generate back mask.
- auto maskValues = SmallVector<bool>(numSrcElemsPerDest, 0);
+ auto maskValues = SmallVector<bool>(emulatedPerContainerElem, 0);
std::fill_n(maskValues.begin(), remainingElements, 1);
auto backMask = rewriter.create<arith::ConstantOp>(
loc, DenseElementsAttr::get(subWidthStoreMaskType, maskValues));
@@ -960,7 +1001,8 @@ struct ConvertVectorMaskedLoad final
// subvector at the proper offset after bit-casting.
auto origType = op.getVectorType();
auto origElements = origType.getNumElements();
- bool isAlignedEmulation = origElements % emulatedPerContainerElem == 0;
+ // Note, per-element-alignment was already verified above.
+ bool isFullyAligned = origElements % emulatedPerContainerElem == 0;
auto stridedMetadata =
rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
@@ -975,9 +1017,8 @@ struct ConvertVectorMaskedLoad final
getAsOpFoldResult(adaptor.getIndices()));
std::optional<int64_t> foldedIntraVectorOffset =
- isAlignedEmulation
- ? 0
- : getConstantIntValue(linearizedInfo.intraDataOffset);
+ isFullyAligned ? 0
+ : getConstantIntValue(linearizedInfo.intraDataOffset);
int64_t maxIntraDataOffset =
foldedIntraVectorOffset.value_or(emulatedPerContainerElem - 1);
@@ -1001,7 +1042,7 @@ struct ConvertVectorMaskedLoad final
passthru = dynamicallyInsertSubVector(
rewriter, loc, passthru, emptyVector, linearizedInfo.intraDataOffset,
origElements);
- } else if (!isAlignedEmulation) {
+ } else if (!isFullyAligned) {
passthru = staticallyInsertSubvector(rewriter, loc, passthru, emptyVector,
*foldedIntraVectorOffset);
}
@@ -1029,7 +1070,7 @@ struct ConvertVectorMaskedLoad final
mask = dynamicallyInsertSubVector(rewriter, loc, mask, emptyMask,
linearizedInfo.intraDataOffset,
origElements);
- } else if (!isAlignedEmulation) {
+ } else if (!isFullyAligned) {
mask = staticallyInsertSubvector(rewriter, loc, op.getMask(), emptyMask,
*foldedIntraVectorOffset);
}
@@ -1040,7 +1081,7 @@ struct ConvertVectorMaskedLoad final
result = dynamicallyExtractSubVector(
rewriter, loc, result, op.getPassThru(),
linearizedInfo.intraDataOffset, origElements);
- } else if (!isAlignedEmulation) {
+ } else if (!isFullyAligned) {
result = staticallyExtractSubvector(
rewriter, loc, result, *foldedIntraVectorOffset, origElements);
}
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned-non-atomic.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned-non-atomic.mlir
index 1d6263535ae80..d27e99a54529c 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned-non-atomic.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned-non-atomic.mlir
@@ -1,5 +1,7 @@
// RUN: mlir-opt --test-emulate-narrow-int="arith-compute-bitwidth=1 memref-load-bitwidth=8 disable-atomic-rmw=true" --cse --split-input-file %s | FileCheck %s
+// NOTE: In this file all RMW stores are non-atomic.
+
// TODO: remove memref.alloc() in the tests to eliminate noises.
// memref.alloc exists here because sub-byte vector data types such as i2
// are currently not supported as input arguments.
@@ -8,121 +10,144 @@
/// vector.store
///----------------------------------------------------------------------------------------
-func.func @vector_store_i2_const_index_two_partial_stores(%arg0: vector<3xi2>) {
- %0 = memref.alloc() : memref<3x3xi2>
+func.func @vector_store_i2_const_index_two_partial_stores(%src: vector<3xi2>) {
+ %dest = memref.alloc() : memref<3x3xi2>
%c0 = arith.constant 0 : index
%c2 = arith.constant 2 : index
- vector.store %arg0, %0[%c2, %c0] :memref<3x3xi2>, vector<3xi2>
+ vector.store %src, %dest[%c2, %c0] :memref<3x3xi2>, vector<3xi2>
return
}
-// Emit two non-atomic RMW partial stores. Store 6 bits from the input vector (bits [12:18)),
-// into bytes [1:2] from a 3-byte output memref. Due to partial storing,
-// both bytes are accessed partially through masking.
-
-// CHECK: func @vector_store_i2_const_index_two_partial_stores(
-// CHECK-SAME: %[[ARG0:.+]]: vector<3xi2>)
-// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
-// CHECK: %[[C1:.+]] = arith.constant 1 : index
-
-// Part 1 RMW sequence
-// CHECK: %[[CST:.+]] = arith.constant dense<[false, false, true, true]>
-// CHECK: %[[CST0:.+]] = arith.constant dense<0> : vector<4xi2>
-// CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[ARG0]]
-// CHECK-SAME: {offsets = [0], sizes = [2], strides = [1]} : vector<3xi2> to vector<2xi2>
-// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[EXTRACT]], %[[CST0]]
-// CHECK-SAME: {offsets = [2], strides = [1]} : vector<2xi2> into vector<4xi2>
-// CHECK: %[[LOAD:.+]] = vector.load
-// CHECK: %[[DOWNCAST:.+]] = vector.bitcast %[[LOAD]] : vector<1xi8> to vector<4xi2>
-// CHECK: %[[SELECT:.+]] = arith.select %[[CST]], %[[INSERT]], %[[DOWNCAST]]
-// CHECK: %[[UPCAST:.+]] = vector.bitcast %[[SELECT]]
-// CHECK: vector.store %[[UPCAST]], %[[ALLOC]][%[[C1]]]
-
-// Part 2 RMW sequence
-// CHECK: %[[OFFSET:.+]] = arith.addi %[[C1]], %[[C1]] : index
-// CHECK: %[[EXTRACT2:.+]] = vector.extract_strided_slice %[[ARG0]]
-// CHECK-SAME: {offsets = [2], sizes = [1], strides = [1]} : vector<3xi2> to vector<1xi2>
-// CHECK: %[[INSERT2:.+]] = vector.insert_strided_slice %[[EXTRACT2]], %[[CST0]]
-// CHECK-SAME: {offsets = [0], strides = [1]} : vector<1xi2> into vector<4xi2>
-// CHECK: %[[CST1:.+]] = arith.constant dense<[true, false, false, false]> : vector<4xi1>
-// CHECK: %[[LOAD2:.+]] = vector.load
-// CHECK: %[[UPCAST2:.+]] = vector.bitcast %[[LOAD2]] : vector<1xi8> to vector<4xi2>
-// CHECK: %[[SELECT2:.+]] = arith.select %[[CST1]], %[[INSERT2]], %[[UPCAST2]]
-// CHECK: %[[DOWNCAST2:.+]] = vector.bitcast %[[SELECT2]]
-// CHECK: vector.store %[[DOWNCAST2]], %[[ALLOC]][%[[OFFSET]]]
-
+// Store 6 bits from the input vector into bytes [1:2] of a 3-byte destination
+// memref, i.e. into bits [12:18) of a 24-bit destination container
+// (`memref<3x3xi2>` is emulated via `memref<3xi8>`). This requires two
+// non-atomic RMW partial stores. Due to partial storing, both bytes are
+// accessed partially through masking.
+
+// CHECK: func @vector_store_i2_const_index_two_partial_stores(
+// CHECK-SAME: %[[SRC:.+]]: vector<3xi2>)
+
+// CHECK: %[[DEST:.+]] = memref.alloc() : memref<3xi8>
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+
+// RMW sequence for Byte 1
+// CHECK: %[[MASK_1:.+]] = arith.constant dense<[false, false, true, true]>
+// CHECK: %[[INIT:.+]] = arith.constant dense<0> : vector<4xi2>
+// CHECK: %[[SRC_SLICE_1:.+]] = vector.extract_strided_slice %[[SRC]]
+// CHECK-SAME: {offsets = [0], sizes = [2], strides = [1]} : vector<3xi2> to vector<2xi2>
+// CHECK: %[[INIT_WITH_SLICE_1:.+]] = vector.insert_strided_slice %[[SRC_SLICE_1]], %[[INIT]]
+// CHECK-SAME: {offsets = [2], strides = [1]} : vector<2xi2> into vector<4xi2>
+// CHECK: %[[DEST_BYTE_1:.+]] = vector.load %[[DEST]][%[[C1]]] : memref<3xi8>, vector<1xi8>
+// CHECK: %[[DEST_BYTE_1_AS_I2:.+]] = vector.bitcast %[[DEST_BYTE_1]]
+// CHECK-SAME: vector<1xi8> to vector<4xi2>
+// CHECK: %[[RES_BYTE_1:.+]] = arith.select %[[MASK_1]], %[[INIT_WITH_SLICE_1]], %[[DEST_BYTE_1_AS_I2]]
+// CHECK: %[[RES_BYTE_1_AS_I8:.+]] = vector.bitcast %[[RES_BYTE_1]]
+// CHECK-SAME: vector<4xi2> to vector<1xi8>
+// CHECK: vector.store %[[RES_BYTE_1_AS_I8]], %[[DEST]][%[[C1]]]
+
+// RMW sequence for Byte 2
+// CHECK: %[[OFFSET:.+]] = arith.addi %[[C1]], %[[C1]] : index
+// CHECK: %[[SRC_SLICE_2:.+]] = vector.extract_strided_slice %[[SRC]]
+// CHECK-SAME: {offsets = [2], sizes = [1], strides = [1]} : vector<3xi2> to vector<1xi2>
+// CHECK: %[[INIT_WITH_SLICE_2:.+]] = vector.insert_strided_slice %[[SRC_SLICE_2]], %[[INIT]]
+// CHECK-SAME: {offsets = [0], strides = [1]} : vector<1xi2> into vector<4xi2>
+// CHECK: %[[MASK_2:.+]] = arith.constant dense<[true, false, false, false]> : vector<4xi1>
+// CHECK: %[[DEST_BYTE_2:.+]] = vector.load %[[DEST]][%[[OFFSET]]] : memref<3xi8>, vector<1xi8>
+// CHECK: %[[DEST_BYTE_2_AS_I2:.+]] = vector.bitcast %[[DEST_BYTE_2]]
+// CHECK-SAME: vector<1xi8> to vector<4xi2>
+// CHECK: %[[RES_BYTE_2:.+]] = arith.select %[[MASK_2]], %[[INIT_WITH_SLICE_2]], %[[DEST_BYTE_2_AS_I2]]
+// CHECK: %[[RES_BYTE_2_AS_I8:.+]] = vector.bitcast %[[RES_BYTE_2]]
+// CHECK-SAME: vector<4xi2> to vector<1xi8>
+// CHECK: vector.store %[[RES_BYTE_2_AS_I8]], %[[DEST]][%[[OFFSET]]]
// -----
-func.func @vector_store_i2_two_partial_one_full_stores(%arg0: vector<7xi2>) {
- %0 = memref.alloc() : memref<3x7xi2>
+func.func @vector_store_i2_two_partial_one_full_stores(%src: vector<7xi2>) {
+ %dest = memref.alloc() : memref<3x7xi2>
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
- vector.store %arg0, %0[%c1, %c0] :memref<3x7xi2>, vector<7xi2>
+ vector.store %src, %dest[%c1, %c0] :memref<3x7xi2>, vector<7xi2>
return
}
-// In this example, emit two RMW stores and one full-width store.
-
-// CHECK: func @vector_store_i2_two_partial_one_full_stores(
-// CHECK-SAME: %[[ARG0:.+]]:
-// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<6xi8>
-// CHECK: %[[C1:.+]] = arith.constant 1 : index
-// CHECK: %[[CST:.+]] = arith.constant dense<[false, false, false, true]>
-// CHECK: %[[CST0:.+]] = arith.constant dense<0> : vector<4xi2>
-// CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[ARG0]]
-// CHECK-SAME: {offsets = [0], sizes = [1], strides = [1]}
-// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[EXTRACT]], %[[CST0]]
-// CHECK-SAME: {offsets = [3], strides = [1]}
-// First sub-width RMW:
-// CHECK: %[[LOAD:.+]] = vector.load %[[ALLOC]][%[[C1]]]
-// CHECK: %[[UPCAST:.+]] = vector.bitcast %[[LOAD]] : vector<1xi8> to vector<4xi2>
-// CHECK: %[[SELECT:.+]] = arith.select %[[CST]], %[[INSERT]], %[[UPCAST]]
-// CHECK: %[[DOWNCAST:.+]] = vector.bitcast %[[SELECT]]
-// CHECK: vector.store %[[DOWNCAST]], %[[ALLOC]][%[[C1]]]
+// Store 14 bits from the input vector into bytes [1:3] of a 6-byte destination
+// memref, i.e. bits [14:28) of a 48-bit destination container memref
+// (`memref<3x7xi2>` is emulated via `memref<6xi8>`). This requires two
+// non-atomic RMW stores (for the "boundary" bytes) and one full byte store
+// (for the "middle" byte). Note that partial stores require masking.
+
+// CHECK: func @vector_store_i2_two_partial_one_full_stores(
+// CHECK-SAME: %[[SRC:.+]]:
+
+// CHECK: %[[DEST:.+]] = memref.alloc() : memref<6xi8>
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+
+// First partial/RMW store:
+// CHECK: %[[MASK_1:.+]] = arith.constant dense<[false, false, false, true]>
+// CHECK: %[[INIT:.+]] = arith.constant dense<0> : vector<4xi2>
+// CHECK: %[[SRC_SLICE_0:.+]] = vector.extract_strided_slice %[[SRC]]
+// CHECK-SAME: {offsets = [0], sizes = [1], strides = [1]}
+// CHECK: %[[INIT_WITH_SLICE_1:.+]] = vector.insert_strided_slice %[[SRC_SLICE_0]], %[[INIT]]
+// CHECK-SAME: {offsets = [3], strides = [1]}
+// CHECK: %[[DEST_BYTE_1:.+]] = vector.load %[[DEST]][%[[C1]]]
+// CHECK: %[[DEST_BYTE_AS_I2:.+]] = vector.bitcast %[[DEST_BYTE_1]]
+// CHECK-SAME: : vector<1xi8> to vector<4xi2>
+// CHECK: %[[RES_BYTE_1:.+]] = arith.select %[[MASK_1]], %[[INIT_WITH_SLICE_1]], %[[DEST_BYTE_AS_I2]]
+// CHECK: %[[RES_BYTE_1_AS_I8:.+]] = vector.bitcast %[[RES_BYTE_1]]
+// CHECK-SAME: : vector<4xi2> to vector<1xi8>
+// CHECK: vector.store %[[RES_BYTE_1_AS_I8]], %[[DEST]][%[[C1]]]
// Full-width store:
-// CHECK: %[[INDEX:.+]] = arith.addi %[[C1]], %[[C1]]
-// CHECK: %[[EXTRACT1:.+]] = vector.extract_strided_slice %[[ARG0]]
-// CHECK-SAME: {offsets = [1], sizes = [4], strides = [1]}
-// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[EXTRACT1]]
-// CHECK: vector.store %[[BITCAST]], %[[ALLOC]][%[[INDEX]]]
-
-// Second sub-width RMW:
-// CHECK: %[[INDEX2:.+]] = arith.addi %[[INDEX]], %[[C1]]
-// CHECK: %[[EXTRACT2:.+]] = vector.extract_strided_slice %[[ARG0]]
-// CHECK-SAME: {offsets = [5], sizes = [2], strides = [1]}
-// CHECK: %[[INSERT2:.+]] = vector.insert_strided_slice %[[EXTRACT2]]
-// CHECK-SAME: {offsets = [0], strides = [1]}
-// CHECK: %[[CST1:.+]] = arith.constant dense<[true, true, false, false]>
-// CHECK: %[[LOAD2:.+]] = vector.load %[[ALLOC]][%[[INDEX2]]]
-// CHECK: %[[UPCAST2:.+]] = vector.bitcast %[[LOAD2]]
-// CHECK: %[[SELECT2:.+]] = arith.select %[[CST1]], %[[INSERT2]], %[[UPCAST2]]
-// CHECK: %[[DOWNCAST2:.+]] = vector.bitcast %[[SELECT2]]
-// CHECK: vector.store %[[DOWNCAST2]], %[[ALLOC]][%[[INDEX2]]]
+// CHECK: %[[C2:.+]] = arith.addi %[[C1]], %[[C1]]
+// CHECK: %[[SRC_SLICE_1:.+]] = vector.extract_strided_slice %[[SRC]]
+// CHECK-SAME: {offsets = [1], sizes = [4], strides = [1]}
+// CHECK: %[[SRC_SLICE_1_AS_I8:.+]] = vector.bitcast %[[SRC_SLICE_1]]
+// CHECK-SAME: : vector<4xi2> to vector<1xi8>
+// CHECK: vector.store %[[SRC_SLICE_1_AS_I8]], %[[DEST]][%[[C2]]]
+
+// Second partial/RMW store:
+// CHECK: %[[C3:.+]] = arith.addi %[[C2]], %[[C1]]
+// CHECK: %[[SRC_SLICE_2:.+]] = vector.extract_strided_slice %[[SRC]]
+// CHECK-SAME: {offsets = [5], sizes = [2], strides = [1]}
+// CHECK: %[[INIT_WITH_SLICE2:.+]] = vector.insert_strided_slice %[[SRC_SLICE_2]]
+// CHECK-SAME: {offsets = [0], strides = [1]}
+// CHECK: %[[MASK_2:.+]] = arith.constant dense<[true, true, false, false]>
+// CHECK: %[[DEST_BYTE_2:.+]] = vector.load %[[DEST]][%[[C3]]]
+// CHECK: %[[DEST_BYTE_2_AS_I2:.+]] = vector.bitcast %[[DEST_BYTE_2]]
+// CHECK: %[[RES_BYTE_2:.+]] = arith.select %[[MASK_2]], %[[INIT_WITH_SLICE2]], %[[DEST_BYTE_2_AS_I2]]
+// CHECK: %[[RES_BYTE_2_AS_I8:.+]] = vector.bitcast %[[RES_BYTE_2]]
+// CHECK-SAME: : vector<4xi2> to vector<1xi8>
+// CHECK: vector.store %[[RES_BYTE_2_AS_I8]], %[[DEST]][%[[C3]]]
// -----
-func.func @vector_store_i2_const_index_one_partial_store(%arg0: vector<1xi2>) {
- %0 = memref.alloc() : memref<4x1xi2>
+func.func @vector_store_i2_const_index_one_partial_store(%src: vector<1xi2>) {
+ %dest = memref.alloc() : memref<4x1xi2>
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
- vector.store %arg0, %0[%c1, %c0] :memref<4x1xi2>, vector<1xi2>
+ vector.store %src, %dest[%c1, %c0] :memref<4x1xi2>, vector<1xi2>
return
}
-// in this test, only emit partial RMW store as the store is within one byte.
-
-// CHECK: func @vector_store_i2_const_index_one_partial_store(
-// CHECK-SAME: %[[ARG0:.+]]: vector<1xi2>)
-// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<1xi8>
-// CHECK: %[[C0:.+]] = arith.constant 0 : index
-// CHECK: %[[CST:.+]] = arith.constant dense<[false, true, false, false]>
-// CHECK: %[[CST0:.+]] = arith.constant dense<0> : vector<4xi2>
-// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[ARG0]], %[[CST0]]
-// CHECK-SAME: {offsets = [1], strides = [1]} : vector<1xi2> into vector<4xi2>
-// CHECK: %[[LOAD:.+]] = vector.load %[[ALLOC]][%[[C0]]] : memref<1xi8>, vector<1xi8>
-// CHECK: %[[UPCAST:.+]] = vector.bitcast %[[LOAD]] : vector<1xi8> to vector<4xi2>
-// CHECK: %[[SELECT:.+]] = arith.select %[[CST]], %[[INSERT]], %[[UPCAST]]
-// CHECK: %[[DOWNCAST:.+]] = vector.bitcast %[[SELECT]]
-// CHECK: vector.store %[[DOWNCAST]], %[[ALLOC]][%[[C0]]]
+// Store 2 bits from the input vector into byte 0 of a 1-byte destination
+// memref, i.e. bits [2:4) of an 8-bit destination container memref
+// (`memref<4x1xi2>` is emulated via `memref<1xi8>`). This requires one
+// non-atomic RMW.
+
+// CHECK: func @vector_store_i2_const_index_one_partial_store(
+// CHECK-SAME: %[[SRC:.+]]: vector<1xi2>)
+
+// CHECK: %[[DEST:.+]] = memref.alloc() : memref<1xi8>
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+
+// CHECK: %[[MASK:.+]] = arith.constant dense<[false, true, false, false]>
+// CHECK: %[[INIT:.+]] = arith.constant dense<0> : vector<4xi2>
+// CHECK: %[[INIT_WITH_SLICE:.+]] = vector.insert_strided_slice %[[SRC]], %[[INIT]]
+// CHECK-SAME: {offsets = [1], strides = [1]} : vector<1xi2> into vector<4xi2>
+// CHECK: %[[DEST_BYTE:.+]] = vector.load %[[DEST]][%[[C0]]] : memref<1xi8>, vector<1xi8>
+// CHECK: %[[DEST_BYTE_AS_I2:.+]] = vector.bitcast %[[DEST_BYTE]]
+// CHECK-SAME: : vector<1xi8> to vector<4xi2>
+// CHECK: %[[RES_BYTE:.+]] = arith.select %[[MASK]], %[[INIT_WITH_SLICE]], %[[DEST_BYTE_AS_I2]]
+// CHECK: %[[RES_BYTE_AS_I8:.+]] = vector.bitcast %[[RES_BYTE]]
+// CHECK-SAME: : vector<4xi2> to vector<1xi8>
+// CHECK: vector.store %[[RES_BYTE_AS_I8]], %[[DEST]][%[[C0]]]
More information about the Mlir-commits
mailing list