[Mlir-commits] [mlir] [mlir][vector] Document `ConvertVectorStore` + unify var names (nfc) (PR #126422)
Andrzej Warzyński
llvmlistbot at llvm.org
Sun Feb 9 08:48:11 PST 2025
https://github.com/banach-space created https://github.com/llvm/llvm-project/pull/126422
1. Documents `ConvertVectorStore`.
2. As a follow-on for #123527, renames `isAlignedEmulation` to
`isFullyAligned` and `numSrcElemsPerDest` to
`emulatedPerContainerElem`.
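To make the renames concrete, here is a small standalone C++ sketch (not part of the patch; the i2/i8 widths are simply the case used in the examples below) showing how the two renamed quantities relate.

#include <cassert>

int main() {
  // Illustrative values only: i2 elements emulated via an i8 container.
  const int emulatedBits = 2;
  const int containerBits = 8;

  // Renamed from `numSrcElemsPerDest`: how many emulated elements fit into
  // one container element.
  const int emulatedPerContainerElem = containerBits / emulatedBits; // 4

  // Renamed from `isAlignedEmulation`: the store is fully aligned only when
  // the stored vector occupies whole container elements.
  const int origElements = 3; // e.g. vector<3xi2>
  const bool isFullyAligned = origElements % emulatedPerContainerElem == 0;
  assert(!isFullyAligned); // a vector<3xi2> store needs partial (RMW) stores

  return 0;
}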
From ffd3552b9f13b3f538753442e75ffee0dd9e69c9 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Sun, 9 Feb 2025 16:33:51 +0000
Subject: [PATCH] [mlir][vector] Document `ConvertVectorStore` + unify var
names (nfc)
1. Documents `ConvertVectorStore`.
2. As a follow-on for #123527, renames `isAlignedEmulation` to
`isFullyAligned` and `numSrcElemsPerDest` to
`emulatedPerContainerElem`.
---
.../Transforms/VectorEmulateNarrowType.cpp | 121 +++++++++++++++---
1 file changed, 101 insertions(+), 20 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index bf1ecd7d4559caf..bb7449d85f079a5 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -432,7 +432,86 @@ namespace {
// ConvertVectorStore
//===----------------------------------------------------------------------===//
-// TODO: Document-me
+// Emulate vector.store using a multi-byte container type
+//
+// The container type is obtained through the Op adaptor and would normally
+// be generated via `NarrowTypeEmulationConverter`.
+//
+// EXAMPLE 1
+// (aligned store of i4, emulated using i8)
+//
+// vector.store %src, %dest[%idx_1, %idx_2] : memref<4x8xi4>, vector<8xi4>
+//
+// is rewritten as:
+//
+// %src_bitcast = vector.bitcast %src : vector<8xi4> to vector<4xi8>
+// vector.store %src_bitcast, %dest_bitcast[%idx]
+// : memref<16xi8>, vector<4xi8>
+//
+// EXAMPLE 2
+// (unaligned store of i2, emulated using i8, non-atomic)
+//
+// vector.store %src, %dest[%c2, %c0] : memref<3x3xi2>, vector<3xi2>
+//
+// The i2 store is emulated through 2 x RMW sequences. The destination i2 memref
+// is modelled using 3 bytes:
+//
+// Byte 0 Byte 1 Byte 2
+// +----------+----------+----------+
+// | oooooooo | ooooNNNN | NNoooooo |
+// +----------+----------+----------+
+//
+// N - (N)ew entries (i.e. to be overwritten by vector.store)
+// o - (o)ld entries (to be preserved)
+//
+// The following 2 RMW sequences will be generated:
+//
+// %init = arith.constant dense<0> : vector<4xi2>
+//
+// (RMW sequence for Byte 1)
+// (Mask for 4 x i2 elements, i.e. a byte)
+// %mask_1 = arith.constant dense<[false, false, true, true]>
+// %src_slice_1 = vector.extract_strided_slice %src
+// {offsets = [0], sizes = [2], strides = [1]}
+// : vector<3xi2> to vector<2xi2>
+// %init_with_slice_1 = vector.insert_strided_slice %src_slice_1, %init
+// {offsets = [2], strides = [1]}
+// : vector<2xi2> into vector<4xi2>
+// %dest_byte_1 = vector.load %dest[%c1]
+// %dest_byte_1_as_i2 = vector.bitcast %dest_byte_1
+// : vector<1xi8> to vector<4xi2>
+// %res_byte_1 = arith.select %mask_1, %init_with_slice_1, %dest_byte_1_as_i2
+// %res_byte_1_as_i8 = vector.bitcast %res_byte_1
+// vector.store %res_byte_1_as_i8, %dest[1]
+//
+// (RMW sequence for Byte 2)
+// (Mask for 4 x i2 elements, i.e. a byte)
+// %mask_2 = arith.constant dense<[true, false, false, false]>
+// %src_slice_2 = vector.extract_strided_slice %src
+// {offsets = [2], sizes = [1], strides = [1]}
+// : vector<3xi2> to vector<1xi2>
+// %init_with_slice_2 = vector.insert_strided_slice %src_slice_2, %init
+// {offsets = [0], strides = [1]}
+// : vector<1xi2> into vector<4xi2>
+// %dest_byte_2 = vector.load %dest[%c2]
+// %dest_byte_2_as_i2 = vector.bitcast %dest_byte_2
+// : vector<1xi8> to vector<4xi2>
+// %res_byte_2 = arith.select %mask_2, %init_with_slice_2, %dest_byte_2_as_i2
+// %res_byte_2_as_i8 = vector.bitcast %res_byte_2
+// vector.store %res_byte_2_as_i8, %dest[2]
+//
+// NOTE: Unlike EXAMPLE 1, this case requires index re-calculation.
+// NOTE: This example assumes that `disableAtomicRMW` was set.
+//
+// EXAMPLE 3
+// (unaligned store of i2, emulated using i8, atomic)
+//
+// Similar to EXAMPLE 2, but with the RMW sequences wrapped in
+// `memref.generic_atomic_rmw` to guarantee atomicity. The actual output is
+// skipped for brevity.
+//
+// NOTE: By default, all RMW sequences are atomic. Set `disableAtomicRMW` to
+// `true` to generate non-atomic RMW sequences.
struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
using OpConversionPattern::OpConversionPattern;
@@ -464,7 +543,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
op, "impossible to pack emulated elements into container elements "
"(bit-wise misalignment)");
}
- int numSrcElemsPerDest = containerBits / emulatedBits;
+ int emulatedPerContainerElem = containerBits / emulatedBits;
// Adjust the number of elements to store when emulating narrow types.
// Here only the 1-D vector store is considered, and the N-D memref types
@@ -480,7 +559,8 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
// vector<4xi8>
auto origElements = valueToStore.getType().getNumElements();
- bool isAlignedEmulation = origElements % numSrcElemsPerDest == 0;
+ // Note: per-element alignment was already verified above.
+ bool isFullyAligned = origElements % emulatedPerContainerElem == 0;
auto stridedMetadata =
rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
@@ -496,7 +576,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
getAsOpFoldResult(adaptor.getIndices()));
std::optional<int64_t> foldedNumFrontPadElems =
- isAlignedEmulation
+ isFullyAligned
? 0
: getConstantIntValue(linearizedInfo.intraDataOffset);
@@ -516,10 +596,10 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
// need unaligned emulation because the store address is aligned and the
// source is a whole byte.
bool emulationRequiresPartialStores =
- !isAlignedEmulation || *foldedNumFrontPadElems != 0;
+ !isFullyAligned || *foldedNumFrontPadElems != 0;
if (!emulationRequiresPartialStores) {
// Basic case: storing full bytes.
- auto numElements = origElements / numSrcElemsPerDest;
+ auto numElements = origElements / emulatedPerContainerElem;
auto bitCast = rewriter.create<vector::BitCastOp>(
loc, VectorType::get(numElements, containerElemTy),
op.getValueToStore());
@@ -567,7 +647,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
// Build a mask used for rmw.
auto subWidthStoreMaskType =
- VectorType::get({numSrcElemsPerDest}, rewriter.getI1Type());
+ VectorType::get({emulatedPerContainerElem}, rewriter.getI1Type());
auto storeFunc = disableAtomicRMW ? nonAtomicRMW : atomicRMW;
@@ -576,10 +656,10 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
// with the unaligned part so that the rest elements are aligned to width
// boundary.
auto frontSubWidthStoreElem =
- (numSrcElemsPerDest - *foldedNumFrontPadElems) % numSrcElemsPerDest;
+ (emulatedPerContainerElem - *foldedNumFrontPadElems) % emulatedPerContainerElem;
if (frontSubWidthStoreElem > 0) {
- SmallVector<bool> frontMaskValues(numSrcElemsPerDest, false);
- if (*foldedNumFrontPadElems + origElements < numSrcElemsPerDest) {
+ SmallVector<bool> frontMaskValues(emulatedPerContainerElem, false);
+ if (*foldedNumFrontPadElems + origElements < emulatedPerContainerElem) {
std::fill_n(frontMaskValues.begin() + *foldedNumFrontPadElems,
origElements, true);
frontSubWidthStoreElem = origElements;
@@ -590,7 +670,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
auto frontMask = rewriter.create<arith::ConstantOp>(
loc, DenseElementsAttr::get(subWidthStoreMaskType, frontMaskValues));
- currentSourceIndex = numSrcElemsPerDest - (*foldedNumFrontPadElems);
+ currentSourceIndex = emulatedPerContainerElem - (*foldedNumFrontPadElems);
auto value =
extractSliceIntoByte(rewriter, loc, valueToStore, 0,
frontSubWidthStoreElem, *foldedNumFrontPadElems);
@@ -614,8 +694,8 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
// After the previous step, the store address is aligned to the emulated
// width boundary.
int64_t fullWidthStoreSize =
- (origElements - currentSourceIndex) / numSrcElemsPerDest;
- int64_t numNonFullWidthElements = fullWidthStoreSize * numSrcElemsPerDest;
+ (origElements - currentSourceIndex) / emulatedPerContainerElem;
+ int64_t numNonFullWidthElements = fullWidthStoreSize * emulatedPerContainerElem;
if (fullWidthStoreSize > 0) {
auto fullWidthStorePart = staticallyExtractSubvector(
rewriter, loc, valueToStore, currentSourceIndex,
@@ -624,7 +704,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
auto originType = cast<VectorType>(fullWidthStorePart.getType());
auto memrefElemType = getElementTypeOrSelf(memrefBase.getType());
auto storeType = VectorType::get(
- {originType.getNumElements() / numSrcElemsPerDest}, memrefElemType);
+ {originType.getNumElements() / emulatedPerContainerElem}, memrefElemType);
auto bitCast = rewriter.create<vector::BitCastOp>(loc, storeType,
fullWidthStorePart);
rewriter.create<vector::StoreOp>(loc, bitCast.getResult(), memrefBase,
@@ -646,7 +726,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
currentSourceIndex, remainingElements, 0);
// Generate back mask.
- auto maskValues = SmallVector<bool>(numSrcElemsPerDest, 0);
+ auto maskValues = SmallVector<bool>(emulatedPerContainerElem, 0);
std::fill_n(maskValues.begin(), remainingElements, 1);
auto backMask = rewriter.create<arith::ConstantOp>(
loc, DenseElementsAttr::get(subWidthStoreMaskType, maskValues));
@@ -960,7 +1040,8 @@ struct ConvertVectorMaskedLoad final
// subvector at the proper offset after bit-casting.
auto origType = op.getVectorType();
auto origElements = origType.getNumElements();
- bool isAlignedEmulation = origElements % emulatedPerContainerElem == 0;
+ // Note: per-element alignment was already verified above.
+ bool isFullyAligned = origElements % emulatedPerContainerElem == 0;
auto stridedMetadata =
rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
@@ -975,7 +1056,7 @@ struct ConvertVectorMaskedLoad final
getAsOpFoldResult(adaptor.getIndices()));
std::optional<int64_t> foldedIntraVectorOffset =
- isAlignedEmulation
+ isFullyAligned
? 0
: getConstantIntValue(linearizedInfo.intraDataOffset);
@@ -1001,7 +1082,7 @@ struct ConvertVectorMaskedLoad final
passthru = dynamicallyInsertSubVector(
rewriter, loc, passthru, emptyVector, linearizedInfo.intraDataOffset,
origElements);
- } else if (!isAlignedEmulation) {
+ } else if (!isFullyAligned) {
passthru = staticallyInsertSubvector(rewriter, loc, passthru, emptyVector,
*foldedIntraVectorOffset);
}
@@ -1029,7 +1110,7 @@ struct ConvertVectorMaskedLoad final
mask = dynamicallyInsertSubVector(rewriter, loc, mask, emptyMask,
linearizedInfo.intraDataOffset,
origElements);
- } else if (!isAlignedEmulation) {
+ } else if (!isFullyAligned) {
mask = staticallyInsertSubvector(rewriter, loc, op.getMask(), emptyMask,
*foldedIntraVectorOffset);
}
@@ -1040,7 +1121,7 @@ struct ConvertVectorMaskedLoad final
result = dynamicallyExtractSubVector(
rewriter, loc, result, op.getPassThru(),
linearizedInfo.intraDataOffset, origElements);
- } else if (!isAlignedEmulation) {
+ } else if (!isFullyAligned) {
result = staticallyExtractSubvector(
rewriter, loc, result, *foldedIntraVectorOffset, origElements);
}
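To make the "index re-calculation" note in EXAMPLE 2 concrete, here is a short standalone C++ sketch (not part of the patch) that redoes the byte/offset arithmetic for memref<3x3xi2> with indices [2, 0], assuming the row-major linearization used in the example.

#include <cassert>
#include <cstdio>

int main() {
  // EXAMPLE 2: vector.store %src, %dest[%c2, %c0]
  //            : memref<3x3xi2>, vector<3xi2>
  const int emulatedBits = 2, containerBits = 8;
  const int emulatedPerContainerElem = containerBits / emulatedBits; // 4
  const int numCols = 3, rowIdx = 2, colIdx = 0, numStoredElems = 3;

  // Row-major linearized index of the first stored i2 element.
  const int linearIdx = rowIdx * numCols + colIdx; // 6

  // Which byte the store starts in, and where within that byte (in elements).
  const int startByte = linearIdx / emulatedPerContainerElem;       // Byte 1
  const int intraByteOffset = linearIdx % emulatedPerContainerElem; // 2

  // Elements 6-7 fill the tail of Byte 1 ("ooooNNNN"); element 8 spills into
  // the head of Byte 2 ("NNoooooo"), hence the two RMW sequences.
  const int elemsInByte1 = emulatedPerContainerElem - intraByteOffset; // 2
  const int elemsInByte2 = numStoredElems - elemsInByte1;              // 1
  assert(startByte == 1 && elemsInByte1 == 2 && elemsInByte2 == 1);
  std::printf("start byte: %d, elems in byte 1: %d, elems in byte 2: %d\n",
              startByte, elemsInByte1, elemsInByte2);
  return 0;
}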