[Mlir-commits] [mlir] [MLIR] Implement emulation of static indexing subbyte type vector stores (PR #115922)
Alan Li
llvmlistbot at llvm.org
Fri Jan 24 03:13:29 PST 2025
https://github.com/lialan updated https://github.com/llvm/llvm-project/pull/115922
>From 358ca604a2343bed342af0e53e9ab13f90f1c035 Mon Sep 17 00:00:00 2001
From: Ubuntu <450283+lialan at users.noreply.github.com>
Date: Fri, 10 Jan 2025 10:17:26 +0000
Subject: [PATCH 1/8] [MLIR] atomic emulation of static indexing subbyte type
vector stores
This patch enables unaligned, statically indexed stores of vectors whose element type is narrower than the emulation width.
To illustrate the mechanism, consider the example of storing vector<7xi2> into memref<3x7xi2>[1, 0].
In this case, the linearized bit range being overwritten is [14, 28), which covers (counting bytes from 1):
* the last 2 bits of byte no.2
* all of byte no.3
* the first 4 bits of byte no.4
Because memory is accessed at byte granularity, bytes no.2 and no.4 in the example above are only partially modified.
In a multi-threaded scenario, these two bytes must therefore be updated atomically to avoid data races on their remaining bits.
---
.../Transforms/VectorEmulateNarrowType.cpp | 286 +++++++++++++++---
.../vector-emulate-narrow-type-unaligned.mlir | 137 +++++++++
2 files changed, 378 insertions(+), 45 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 95064083b21d44..76691966ee66b7 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -45,6 +45,9 @@ using namespace mlir;
#define DBGSNL() (llvm::dbgs() << "\n")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
+/// Shorthands for SSA values whose type is statically known to be a vector
+/// type / memref type, respectively.
+using VectorValue = TypedValue<VectorType>;
+using MemRefValue = TypedValue<MemRefType>;
+
/// Returns a compressed mask for the emulated vector. For example, when
/// emulating an eight-element `i8` vector with `i32` (i.e. when the source
/// elements span two dest elements), this method compresses `vector<8xi1>`
@@ -194,13 +197,10 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
/// Extracts 1-D subvector from a 1-D vector. It is a wrapper function for
/// emitting `vector.extract_strided_slice`.
static Value staticallyExtractSubvector(OpBuilder &rewriter, Location loc,
- VectorType extractType, Value source,
- int64_t frontOffset,
+ Value source, int64_t frontOffset,
int64_t subvecSize) {
auto vectorType = cast<VectorType>(source.getType());
- assert((vectorType.getRank() == 1 && extractType.getRank() == 1) &&
- "expected 1-D source and destination types");
- (void)vectorType;
+ assert(vectorType.getRank() == 1 && "expected 1-D source types");
assert(frontOffset + subvecSize <= vectorType.getNumElements() &&
"subvector out of bounds");
@@ -211,9 +211,12 @@ static Value staticallyExtractSubvector(OpBuilder &rewriter, Location loc,
auto offsets = rewriter.getI64ArrayAttr({frontOffset});
auto sizes = rewriter.getI64ArrayAttr({subvecSize});
auto strides = rewriter.getI64ArrayAttr({1});
+
+ // The result type is no longer passed in by callers: it is derived from the
+ // slice size and the source element type.
+ auto resultVectorType =
+ VectorType::get({subvecSize}, vectorType.getElementType());
return rewriter
- .create<vector::ExtractStridedSliceOp>(loc, extractType, source, offsets,
- sizes, strides)
+ .create<vector::ExtractStridedSliceOp>(loc, resultVectorType, source,
+ offsets, sizes, strides)
->getResult(0);
}
@@ -237,8 +240,8 @@ static Value staticallyInsertSubvector(OpBuilder &rewriter, Location loc,
/// function emits multiple `vector.extract` and `vector.insert` ops, so only
/// use it when `offset` cannot be folded into a constant value.
static Value dynamicallyExtractSubVector(OpBuilder &rewriter, Location loc,
- TypedValue<VectorType> source,
- Value dest, OpFoldResult offset,
+ VectorValue source, Value dest,
+ OpFoldResult offset,
int64_t numElementsToExtract) {
for (int i = 0; i < numElementsToExtract; ++i) {
Value extractLoc =
@@ -255,8 +258,8 @@ static Value dynamicallyExtractSubVector(OpBuilder &rewriter, Location loc,
/// Inserts a 1-D subvector into a 1-D `dest` vector at index `destOffsetVar`.
static Value dynamicallyInsertSubVector(RewriterBase &rewriter, Location loc,
- TypedValue<VectorType> source,
- Value dest, OpFoldResult destOffsetVar,
+ VectorValue source, Value dest,
+ OpFoldResult destOffsetVar,
size_t length) {
assert(length > 0 && "length must be greater than 0");
Value destOffsetVal =
@@ -277,11 +280,12 @@ static Value dynamicallyInsertSubVector(RewriterBase &rewriter, Location loc,
/// specifically, use `emulatedElemType` for loading a vector of `origElemType`.
/// The load location is given by `base` and `linearizedIndices`, and the
/// load size is given by `numEmulatedElementsToLoad`.
-static TypedValue<VectorType>
-emulatedVectorLoad(OpBuilder &rewriter, Location loc, Value base,
- OpFoldResult linearizedIndices,
- int64_t numEmultedElementsToLoad, Type origElemType,
- Type emulatedElemType) {
+// NOTE(review): the parameter is spelled `numEmultedElementsToLoad` while the
+// doc above says `numEmulatedElementsToLoad`; rename both in a follow-up.
+static VectorValue emulatedVectorLoad(OpBuilder &rewriter, Location loc,
+ Value base,
+ OpFoldResult linearizedIndices,
+ int64_t numEmultedElementsToLoad,
+ Type origElemType,
+ Type emulatedElemType) {
auto scale = emulatedElemType.getIntOrFloatBitWidth() /
origElemType.getIntOrFloatBitWidth();
auto newLoad = rewriter.create<vector::LoadOp>(
@@ -292,6 +296,89 @@ emulatedVectorLoad(OpBuilder &rewriter, Location loc, Value base,
newLoad);
}
+/// Selects element-wise between `trueValue` and `falseValue` using `mask`, and
+/// bitcasts the selected value to `castIntoType`.
+static Value selectAndCast(OpBuilder &builder, Location loc,
+ VectorType castIntoType, Value mask, Value trueValue,
+ Value falseValue) {
+ Value maskedValue =
+ builder.create<arith::SelectOp>(loc, mask, trueValue, falseValue);
+ return builder.create<vector::BitCastOp>(loc, castIntoType, maskedValue);
+}
+
+/// Emits `memref.generic_atomic_rmw` op to store a subbyte-sized value to a
+/// byte in memory, with a mask. The `valueToStore` is a vector of subbyte-sized
+/// elements whose total size is 8 bits, and the mask selects which elements
+/// to store; unselected elements keep their current value in memory.
+/// NOTE(review): the memref element is bitcast to `valueToStore`'s type, so
+/// both must have the same bit width (an i8 memref here) - confirm callers.
+///
+/// Inputs:
+/// linearizedMemref = |2|2|2|2| : <4xi2> (<1xi8>)
+/// linearizedIndex = 0
+/// valueToStore = |3|3|3|3| : vector<4xi2>
+/// mask = |0|0|1|1| : vector<4xi1>
+///
+/// Result:
+/// linearizedMemref = |2|2|3|3| : <4xi2> (<1xi8>)
+static void atomicStore(OpBuilder &builder, Location loc,
+ MemRefValue linearizedMemref, Value linearizedIndex,
+ VectorValue valueToStore, Value mask) {
+ assert(valueToStore.getType().getRank() == 1 && "expected 1-D vector");
+
+ // Create an atomic load-modify-write region using
+ // `memref.generic_atomic_rmw`.
+ auto atomicOp = builder.create<memref::GenericAtomicRMWOp>(
+ loc, linearizedMemref, ValueRange{linearizedIndex});
+ Value origValue = atomicOp.getCurrentValue();
+
+ OpBuilder::InsertionGuard guard(builder);
+ builder.setInsertionPointToStart(atomicOp.getBody());
+
+ // Wrap the current memory value in a one-element vector and bitcast it to
+ // the sub-byte vector type of `valueToStore`.
+ auto oneElemVecType = VectorType::get({1}, origValue.getType());
+ Value origVecValue = builder.create<vector::FromElementsOp>(
+ loc, oneElemVecType, ValueRange{origValue});
+ origVecValue = builder.create<vector::BitCastOp>(loc, valueToStore.getType(),
+ origVecValue);
+
+ // Select between new and current elements, bitcast back to the memref
+ // element type, and yield the result as a scalar.
+ Value maskedValue = selectAndCast(builder, loc, oneElemVecType, mask,
+ valueToStore, origVecValue);
+ auto scalarMaskedValue =
+ builder.create<vector::ExtractOp>(loc, maskedValue, 0);
+ builder.create<memref::AtomicYieldOp>(loc, scalarMaskedValue);
+}
+
+/// Extract `sliceNumElements` from source `vector` at `sliceOffset`,
+/// and insert it into an empty vector at offset `byteOffset`.
+/// Inputs:
+/// vector = |1|2|3|4| : vector<4xi2>
+/// sliceOffset = 1
+/// sliceNumElements = 2
+/// byteOffset = 2
+/// Output:
+/// vector = |0|0|2|3| : vector<4xi2>
+/// Assumes an 8-bit container: the result always has
+/// `8 / bitwidth(elementType)` elements, zero outside the inserted slice.
+static Value extractSliceIntoByte(ConversionPatternRewriter &rewriter,
+ Location loc, VectorValue vector,
+ int64_t sliceOffset, int64_t sliceNumElements,
+ int64_t byteOffset) {
+ assert(vector.getType().getRank() == 1 && "expected 1-D vector");
+ auto vectorElementType = vector.getType().getElementType();
+ assert(
+ sliceNumElements * vectorElementType.getIntOrFloatBitWidth() <= 8 &&
+ "sliceNumElements * vector element size must be less than or equal to 8");
+ assert(8 % vectorElementType.getIntOrFloatBitWidth() == 0 &&
+ "vector element must be a valid sub-byte type");
+ auto scale = 8 / vectorElementType.getIntOrFloatBitWidth();
+ auto emptyByteVector = rewriter.create<arith::ConstantOp>(
+ loc, VectorType::get({scale}, vectorElementType),
+ rewriter.getZeroAttr(VectorType::get({scale}, vectorElementType)));
+ auto extracted = staticallyExtractSubvector(rewriter, loc, vector,
+ sliceOffset, sliceNumElements);
+ return staticallyInsertSubvector(rewriter, loc, extracted, emptyByteVector,
+ byteOffset);
+}
+
namespace {
//===----------------------------------------------------------------------===//
@@ -301,6 +388,9 @@ namespace {
 struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
   using OpConversionPattern::OpConversionPattern;
 
+  // NOTE(review): `useAtomicWrites` is currently ignored; partial-container
+  // stores below are always emitted as `memref.generic_atomic_rmw`. Either
+  // keep it in a member and honour it, or drop the parameter.
+  ConvertVectorStore(MLIRContext *context, bool useAtomicWrites)
+      : OpConversionPattern<vector::StoreOp>(context) {}
+
   LogicalResult
   matchAndRewrite(vector::StoreOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
@@ -312,8 +402,9 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
 
     auto loc = op.getLoc();
     auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
-    Type oldElementType = op.getValueToStore().getType().getElementType();
-    Type newElementType = convertedType.getElementType();
+    auto valueToStore = cast<VectorValue>(op.getValueToStore());
+    auto oldElementType = valueToStore.getType().getElementType();
+    auto newElementType = convertedType.getElementType();
     int srcBits = oldElementType.getIntOrFloatBitWidth();
     int dstBits = newElementType.getIntOrFloatBitWidth();
 
@@ -321,7 +412,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
       return rewriter.notifyMatchFailure(
           op, "only dstBits % srcBits == 0 supported");
     }
-    int scale = dstBits / srcBits;
+    int numSrcElemsPerDest = dstBits / srcBits;
 
     // Adjust the number of elements to store when emulating narrow types.
     // Here only the 1-D vector store is considered, and the N-D memref types
@@ -336,15 +427,15 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
     //   vector.store %bitcast, %alloc[%linear_index] : memref<16xi8>,
     //   vector<4xi8>
 
-    auto origElements = op.getValueToStore().getType().getNumElements();
-    if (origElements % scale != 0)
-      return failure();
+    auto origElements = valueToStore.getType().getNumElements();
+    bool isUnalignedEmulation = origElements % numSrcElemsPerDest != 0;
 
     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
 
     OpFoldResult linearizedIndices;
-    std::tie(std::ignore, linearizedIndices) =
+    memref::LinearizedMemRefInfo linearizedInfo;
+    std::tie(linearizedInfo, linearizedIndices) =
         memref::getLinearizedMemRefOffsetAndSize(
             rewriter, loc, srcBits, dstBits,
             stridedMetadata.getConstifiedMixedOffset(),
@@ -352,14 +443,122 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
             stridedMetadata.getConstifiedMixedStrides(),
             getAsOpFoldResult(adaptor.getIndices()));
 
-    auto numElements = origElements / scale;
-    auto bitCast = rewriter.create<vector::BitCastOp>(
-        loc, VectorType::get(numElements, newElementType),
-        op.getValueToStore());
+    // Number of source elements that precede the store address inside its
+    // first container element. An aligned-size store is assumed to start on a
+    // container boundary, so its padding is 0 without folding.
+    // NOTE(review): an aligned-size store at an unaligned index (e.g.
+    // `vector<4xi2>` at index 1 of `memref<8xi2>`) also gets 0 here and takes
+    // the full-width shortcut below; confirm callers rule this case out.
+    std::optional<int64_t> foldedNumFrontPadElems =
+        isUnalignedEmulation
+            ? getConstantIntValue(linearizedInfo.intraDataOffset)
+            : 0;
+
+    if (!foldedNumFrontPadElems) {
+      return rewriter.notifyMatchFailure(
+          op, "subbyte store emulation with a dynamic front padding size is "
+              "not implemented");
+    }
+
+    auto linearizedMemref = cast<MemRefValue>(adaptor.getBase());
+
+    // Shortcut: no sub-container store at the front is needed when
+    // 1. the source vector size is a multiple of the container size, and
+    // 2. the store address is aligned to the container boundary.
+    if (!isUnalignedEmulation && *foldedNumFrontPadElems == 0) {
+      auto numElements = origElements / numSrcElemsPerDest;
+      auto bitCast = rewriter.create<vector::BitCastOp>(
+          loc, VectorType::get(numElements, newElementType), valueToStore);
+      rewriter.replaceOpWithNewOp<vector::StoreOp>(
+          op, bitCast.getResult(), linearizedMemref,
+          getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
+      return success();
+    }
+
+    // The partial-store path below builds each container element through
+    // `extractSliceIntoByte`/`atomicStore`, which assume an 8-bit container.
+    // With a wider container (e.g. i32) the mask built from
+    // `numSrcElemsPerDest` and the 8-bit value vector would not match.
+    if (dstBits != 8) {
+      return rewriter.notifyMatchFailure(
+          op, "unaligned subbyte store emulation requires an 8-bit container");
+    }
+
+    // The index into the target memref we are storing to.
+    Value currentDestIndex =
+        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices);
+    auto subWidthStoreMaskType =
+        VectorType::get({numSrcElemsPerDest}, rewriter.getI1Type());
+    // The index into the source vector we are currently processing.
+    int64_t currentSourceIndex = 0;
+
+    // 1. Partial width store for the first byte. When the store address is
+    // not aligned to the emulated width boundary, handle the unaligned part so
+    // that the remaining elements are aligned to the width boundary.
+    int64_t frontSubWidthStoreElem =
+        (numSrcElemsPerDest - *foldedNumFrontPadElems) % numSrcElemsPerDest;
+    if (frontSubWidthStoreElem > 0) {
+      SmallVector<bool> frontMaskValues(numSrcElemsPerDest, false);
+      if (*foldedNumFrontPadElems + origElements < numSrcElemsPerDest) {
+        // The whole source vector fits strictly inside this byte.
+        std::fill_n(frontMaskValues.begin() + *foldedNumFrontPadElems,
+                    origElements, true);
+        frontSubWidthStoreElem = origElements;
+      } else {
+        // Enable exactly the trailing `frontSubWidthStoreElem` slots. Filling
+        // `*foldedNumFrontPadElems` slots here would write past the end of
+        // the mask whenever the padding exceeds the remaining slots.
+        std::fill_n(frontMaskValues.end() - frontSubWidthStoreElem,
+                    frontSubWidthStoreElem, true);
+      }
+      auto frontMask = rewriter.create<arith::ConstantOp>(
+          loc, DenseElementsAttr::get(subWidthStoreMaskType, frontMaskValues));
+
+      currentSourceIndex = numSrcElemsPerDest - (*foldedNumFrontPadElems);
+      auto value =
+          extractSliceIntoByte(rewriter, loc, valueToStore, 0,
+                               frontSubWidthStoreElem, *foldedNumFrontPadElems);
 
-    rewriter.replaceOpWithNewOp<vector::StoreOp>(
-        op, bitCast.getResult(), adaptor.getBase(),
-        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
+      atomicStore(rewriter, loc, linearizedMemref, currentDestIndex,
+                  cast<VectorValue>(value), frontMask.getResult());
+    }
+
+    if (currentSourceIndex >= origElements) {
+      rewriter.eraseOp(op);
+      return success();
+    }
+
+    // Only when a leading partial store was emitted (which is exactly when
+    // `currentSourceIndex` moved past 0) does the destination index need to
+    // advance by 1 to reach the first container-aligned byte. Otherwise the
+    // store already starts on a container boundary.
+    if (currentSourceIndex > 0) {
+      auto constantOne = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+      currentDestIndex = rewriter.create<arith::AddIOp>(
+          loc, rewriter.getIndexType(), currentDestIndex, constantOne);
+    }
+
+    // 2. Full width store. After the previous step, the store address is
+    // aligned to the emulated width boundary.
+    int64_t fullWidthStoreSize =
+        (origElements - currentSourceIndex) / numSrcElemsPerDest;
+    int64_t numFullWidthElements = fullWidthStoreSize * numSrcElemsPerDest;
+    if (fullWidthStoreSize > 0) {
+      auto fullWidthStorePart = staticallyExtractSubvector(
+          rewriter, loc, valueToStore, currentSourceIndex,
+          numFullWidthElements);
+
+      auto originType = cast<VectorType>(fullWidthStorePart.getType());
+      auto memrefElemType = getElementTypeOrSelf(linearizedMemref.getType());
+      auto storeType = VectorType::get(
+          {originType.getNumElements() / numSrcElemsPerDest}, memrefElemType);
+      auto bitCast = rewriter.create<vector::BitCastOp>(loc, storeType,
+                                                        fullWidthStorePart);
+      rewriter.create<vector::StoreOp>(loc, bitCast.getResult(),
+                                       linearizedMemref, currentDestIndex);
+
+      currentSourceIndex += numFullWidthElements;
+      currentDestIndex = rewriter.create<arith::AddIOp>(
+          loc, rewriter.getIndexType(), currentDestIndex,
+          rewriter.create<arith::ConstantIndexOp>(loc, fullWidthStoreSize));
+    }
+
+    // 3. Deal with trailing elements that are aligned to the emulated width,
+    // but whose count is smaller than the emulated width.
+    auto remainingElements = origElements - currentSourceIndex;
+    if (remainingElements != 0) {
+      auto subWidthStorePart =
+          extractSliceIntoByte(rewriter, loc, valueToStore,
+                               currentSourceIndex, remainingElements, 0);
+
+      // Generate the back mask: the leading `remainingElements` slots.
+      auto maskValues = SmallVector<bool>(numSrcElemsPerDest, false);
+      std::fill_n(maskValues.begin(), remainingElements, true);
+      auto backMask = rewriter.create<arith::ConstantOp>(
+          loc, DenseElementsAttr::get(subWidthStoreMaskType, maskValues));
+
+      atomicStore(rewriter, loc, linearizedMemref, currentDestIndex,
+                  cast<VectorValue>(subWidthStorePart), backMask.getResult());
+    }
+
+    rewriter.eraseOp(op);
     return success();
   }
 };
@@ -564,12 +763,11 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
auto resultVector = rewriter.create<arith::ConstantOp>(
loc, op.getType(), rewriter.getZeroAttr(op.getType()));
+ // `cast` rather than `dyn_cast`: a non-vector value here is a bug and should
+ // assert instead of passing a null value on to the helper.
result = dynamicallyExtractSubVector(
- rewriter, loc, dyn_cast<TypedValue<VectorType>>(result), resultVector,
+ rewriter, loc, cast<VectorValue>(result), resultVector,
linearizedInfo.intraDataOffset, origElements);
} else if (isUnalignedEmulation) {
- result =
- staticallyExtractSubvector(rewriter, loc, op.getType(), result,
- *foldedIntraVectorOffset, origElements);
+ // Constant offset: one extract_strided_slice suffices, and its result type
+ // is now derived from `result` inside staticallyExtractSubvector.
+ result = staticallyExtractSubvector(
+ rewriter, loc, result, *foldedIntraVectorOffset, origElements);
}
rewriter.replaceOp(op, result);
return success();
@@ -685,8 +883,8 @@ struct ConvertVectorMaskedLoad final
loc, newBitcastType, rewriter.getZeroAttr(newBitcastType));
if (!foldedIntraVectorOffset) {
passthru = dynamicallyInsertSubVector(
- rewriter, loc, dyn_cast<TypedValue<VectorType>>(passthru),
- emptyVector, linearizedInfo.intraDataOffset, origElements);
+ rewriter, loc, cast<VectorValue>(passthru), emptyVector,
+ linearizedInfo.intraDataOffset, origElements);
} else if (isUnalignedEmulation) {
passthru = staticallyInsertSubvector(rewriter, loc, passthru, emptyVector,
*foldedIntraVectorOffset);
@@ -713,7 +911,7 @@ struct ConvertVectorMaskedLoad final
loc, newSelectMaskType, rewriter.getZeroAttr(newSelectMaskType));
if (!foldedIntraVectorOffset) {
mask = dynamicallyInsertSubVector(
- rewriter, loc, dyn_cast<TypedValue<VectorType>>(mask), emptyMask,
+ rewriter, loc, cast<VectorValue>(mask), emptyMask,
linearizedInfo.intraDataOffset, origElements);
} else if (isUnalignedEmulation) {
mask = staticallyInsertSubvector(rewriter, loc, op.getMask(), emptyMask,
@@ -724,12 +922,11 @@ struct ConvertVectorMaskedLoad final
rewriter.create<arith::SelectOp>(loc, mask, bitCast, passthru);
if (!foldedIntraVectorOffset) {
result = dynamicallyExtractSubVector(
- rewriter, loc, dyn_cast<TypedValue<VectorType>>(result),
- op.getPassThru(), linearizedInfo.intraDataOffset, origElements);
+ rewriter, loc, cast<VectorValue>(result), op.getPassThru(),
+ linearizedInfo.intraDataOffset, origElements);
} else if (isUnalignedEmulation) {
- result =
- staticallyExtractSubvector(rewriter, loc, op.getType(), result,
- *foldedIntraVectorOffset, origElements);
+ result = staticallyExtractSubvector(
+ rewriter, loc, result, *foldedIntraVectorOffset, origElements);
}
rewriter.replaceOp(op, result);
@@ -812,9 +1009,8 @@ struct ConvertVectorTransferRead final
linearizedInfo.intraDataOffset,
origElements);
} else if (isUnalignedEmulation) {
- result =
- staticallyExtractSubvector(rewriter, loc, op.getType(), result,
- *foldedIntraVectorOffset, origElements);
+ result = staticallyExtractSubvector(
+ rewriter, loc, result, *foldedIntraVectorOffset, origElements);
}
rewriter.replaceOp(op, result);
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
index 4332e80feed421..b01f9165d9eb74 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
@@ -356,3 +356,140 @@ func.func @vector_maskedload_i2_constant_mask_unaligned(%passthru: vector<5xi2>)
// CHECK: %[[RESULT:.+]] = vector.extract_strided_slice %[[SELECT]]
// CHECK-SAME: {offsets = [1], sizes = [5], strides = [1]} : vector<8xi2> to vector<5xi2>
// CHECK: return %[[RESULT]] : vector<5xi2>
+
+///----------------------------------------------------------------------------------------
+/// vector.store
+///----------------------------------------------------------------------------------------
+
+func.func @vector_store_i2_const_index_two_atomic(%arg0: vector<3xi2>) {
+ %0 = memref.alloc() : memref<3x3xi2>
+ %c0 = arith.constant 0 : index
+ %c2 = arith.constant 2 : index
+ vector.store %arg0, %0[%c2, %c0] :memref<3x3xi2>, vector<3xi2>
+ return
+}
+
+// In this example, emit 2 atomic RMWs.
+// Store to bits [12, 18), i.e. bytes [1, 2] of 3 total bytes; both bytes need an RMW.
+
+// CHECK-LABEL: func @vector_store_i2_const_index_two_atomic(
+// CHECK-SAME: %[[ARG0:.+]]: vector<3xi2>)
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[CST:.+]] = arith.constant dense<[false, false, true, true]> : vector<4xi1>
+// CHECK: %[[CST_0:.+]] = arith.constant dense<0> : vector<4xi2>
+
+// Part 1 atomic RMW sequence
+// CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [0], sizes = [2], strides = [1]} : vector<3xi2> to vector<2xi2>
+// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[EXTRACT]], %[[CST_0]]
+// CHECK-SAME: {offsets = [2], strides = [1]} : vector<2xi2> into vector<4xi2>
+// CHECK: %[[ATOMIC_RMW:.+]] = memref.generic_atomic_rmw %[[ALLOC]][%[[C1]]] : memref<3xi8> {
+// CHECK: %[[ARG:.+]]: i8):
+// CHECK: %[[FROM_ELEM:.+]] = vector.from_elements %[[ARG]] : vector<1xi8>
+// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[FROM_ELEM]] : vector<1xi8> to vector<4xi2>
+// CHECK: %[[SELECT:.+]] = arith.select %[[CST]], %[[INSERT]], %[[BITCAST]] : vector<4xi1>, vector<4xi2>
+// CHECK: %[[BITCAST2:.+]] = vector.bitcast %[[SELECT]] : vector<4xi2> to vector<1xi8>
+// CHECK: %[[EXTRACT2:.+]] = vector.extract %[[BITCAST2]][0] : i8 from vector<1xi8>
+// CHECK: memref.atomic_yield %[[EXTRACT2]] : i8
+
+// Part 2 atomic RMW sequence
+// CHECK: %[[ADDR2:.+]] = arith.addi %[[C1]], %[[C1]] : index
+// CHECK: %[[EXTRACT3:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [2], sizes = [1], strides = [1]} : vector<3xi2> to vector<1xi2>
+// CHECK: %[[INSERT2:.+]] = vector.insert_strided_slice %[[EXTRACT3]], %[[CST_0]]
+// CHECK-SAME: {offsets = [0], strides = [1]} : vector<1xi2> into vector<4xi2>
+// CHECK: %[[CST1:.+]] = arith.constant dense<[true, false, false, false]> : vector<4xi1>
+// CHECK: %[[ATOMIC_RMW2:.+]] = memref.generic_atomic_rmw %[[ALLOC]][%[[ADDR2]]] : memref<3xi8> {
+// CHECK: %[[ARG2:.+]]: i8):
+// CHECK: %[[FROM_ELEM2:.+]] = vector.from_elements %[[ARG2]] : vector<1xi8>
+// CHECK: %[[BITCAST4:.+]] = vector.bitcast %[[FROM_ELEM2]] : vector<1xi8> to vector<4xi2>
+// CHECK: %[[SELECT2:.+]] = arith.select %[[CST1]], %[[INSERT2]], %[[BITCAST4]] : vector<4xi1>, vector<4xi2>
+// CHECK: %[[BITCAST5:.+]] = vector.bitcast %[[SELECT2]] : vector<4xi2> to vector<1xi8>
+// CHECK: %[[EXTRACT4:.+]] = vector.extract %[[BITCAST5]][0] : i8 from vector<1xi8>
+// CHECK: memref.atomic_yield %[[EXTRACT4]] : i8
+
+// -----
+
+func.func @vector_store_i2_atomic(%arg0: vector<7xi2>) {
+ %0 = memref.alloc() : memref<3x7xi2>
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ vector.store %arg0, %0[%c1, %c0] :memref<3x7xi2>, vector<7xi2>
+ return
+}
+
+// Store to bits [14, 28): bytes 1 and 3 are only partially overwritten, byte 2 fully.
+// In this example, emit 2 atomic RMWs and 1 non-atomic store:
+// CHECK-LABEL: func @vector_store_i2_atomic(
+// CHECK-SAME: %[[ARG0:.+]]: vector<7xi2>)
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<6xi8>
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[CST:.+]] = arith.constant dense<[false, false, false, true]> : vector<4xi1>
+// CHECK: %[[CST0:.+]] = arith.constant dense<0> : vector<4xi2>
+
+// First atomic RMW:
+// CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [0], sizes = [1], strides = [1]} : vector<7xi2> to vector<1xi2>
+// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[EXTRACT]], %[[CST0]]
+// CHECK-SAME: {offsets = [3], strides = [1]} : vector<1xi2> into vector<4xi2>
+// CHECK: %[[ATOMIC_RMW:.+]] = memref.generic_atomic_rmw %[[ALLOC]][%[[C1]]] : memref<6xi8> {
+// CHECK: %[[ARG:.+]]: i8):
+// CHECK: %[[FROM_ELEM:.+]] = vector.from_elements %[[ARG]] : vector<1xi8>
+// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[FROM_ELEM]] : vector<1xi8> to vector<4xi2>
+// CHECK: %[[SELECT:.+]] = arith.select %[[CST]], %[[INSERT]], %[[BITCAST]] : vector<4xi1>, vector<4xi2>
+// CHECK: %[[BITCAST2:.+]] = vector.bitcast %[[SELECT]] : vector<4xi2> to vector<1xi8>
+// CHECK: %[[EXTRACT2:.+]] = vector.extract %[[BITCAST2]][0] : i8 from vector<1xi8>
+// CHECK: memref.atomic_yield %[[EXTRACT2]] : i8
+
+// Non-atomic store:
+// CHECK: %[[ADDR:.+]] = arith.addi %[[C1]], %[[C1]] : index
+// CHECK: %[[EXTRACT2:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [1], sizes = [4], strides = [1]} : vector<7xi2> to vector<4xi2>
+// CHECK: %[[BITCAST3:.+]] = vector.bitcast %[[EXTRACT2]] : vector<4xi2> to vector<1xi8>
+// CHECK: vector.store %[[BITCAST3]], %[[ALLOC]][%[[ADDR]]] : memref<6xi8>, vector<1xi8>
+
+// Second atomic RMW:
+// CHECK: %[[ADDR2:.+]] = arith.addi %[[ADDR]], %[[C1]] : index
+// CHECK: %[[EXTRACT3:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [5], sizes = [2], strides = [1]} : vector<7xi2> to vector<2xi2>
+// CHECK: %[[INSERT2:.+]] = vector.insert_strided_slice %[[EXTRACT3]], %[[CST0]]
+// CHECK-SAME: {offsets = [0], strides = [1]} : vector<2xi2> into vector<4xi2>
+// CHECK: %[[CST1:.+]] = arith.constant dense<[true, true, false, false]> : vector<4xi1>
+// CHECK: %[[ATOMIC_RMW2:.+]] = memref.generic_atomic_rmw %[[ALLOC]][%[[ADDR2]]] : memref<6xi8> {
+// CHECK: %[[ARG2:.+]]: i8):
+// CHECK: %[[FROM_ELEM2:.+]] = vector.from_elements %[[ARG2]] : vector<1xi8>
+// CHECK: %[[BITCAST4:.+]] = vector.bitcast %[[FROM_ELEM2]] : vector<1xi8> to vector<4xi2>
+// CHECK: %[[SELECT2:.+]] = arith.select %[[CST1]], %[[INSERT2]], %[[BITCAST4]] :
+// CHECK-SAME: vector<4xi1>, vector<4xi2>
+// CHECK: %[[BITCAST5:.+]] = vector.bitcast %[[SELECT2]] : vector<4xi2> to vector<1xi8>
+// CHECK: %[[EXTRACT4:.+]] = vector.extract %[[BITCAST5]][0] : i8 from vector<1xi8>
+// CHECK: memref.atomic_yield %[[EXTRACT4]] : i8
+
+// -----
+
+func.func @vector_store_i2_single_atomic(%arg0: vector<1xi2>) {
+ %0 = memref.alloc() : memref<4x1xi2>
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ vector.store %arg0, %0[%c1, %c0] :memref<4x1xi2>, vector<1xi2>
+ return
+}
+
+// In this example, only emit 1 atomic RMW: bits [2, 4) lie entirely within byte 0.
+// CHECK-LABEL: func @vector_store_i2_single_atomic(
+// CHECK-SAME: %[[ARG0:.+]]: vector<1xi2>)
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<1xi8>
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[CST:.+]] = arith.constant dense<[false, true, false, false]> : vector<4xi1>
+// CHECK: %[[CST0:.+]] = arith.constant dense<0> : vector<4xi2>
+// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[ARG0]], %[[CST0]]
+// CHECK-SAME: {offsets = [1], strides = [1]} : vector<1xi2> into vector<4xi2>
+
+// CHECK: %[[ATOMIC_RMW:.+]] = memref.generic_atomic_rmw %[[ALLOC]][%[[C0]]] : memref<1xi8> {
+// CHECK: %[[ARG:.+]]: i8):
+// CHECK: %[[FROM_ELEM:.+]] = vector.from_elements %[[ARG]] : vector<1xi8>
+// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[FROM_ELEM]] : vector<1xi8> to vector<4xi2>
+// CHECK: %[[SELECT:.+]] = arith.select %[[CST]], %[[INSERT]], %[[BITCAST]] : vector<4xi1>, vector<4xi2>
+// CHECK: %[[BITCAST2:.+]] = vector.bitcast %[[SELECT]] : vector<4xi2> to vector<1xi8>
+// CHECK: %[[EXTRACT2:.+]] = vector.extract %[[BITCAST2]][0] : i8 from vector<1xi8>
+// CHECK: memref.atomic_yield %[[EXTRACT2]] : i8
>From 70de874277ebba717ea05531192015a7e7230242 Mon Sep 17 00:00:00 2001
From: Ubuntu <450283+lialan at users.noreply.github.com>
Date: Sat, 11 Jan 2025 09:19:39 +0000
Subject: [PATCH 2/8] Update according to review comments
---
.../Transforms/VectorEmulateNarrowType.cpp | 65 +++++++++++--------
.../vector-emulate-narrow-type-unaligned.mlir | 24 +++----
2 files changed, 51 insertions(+), 38 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 76691966ee66b7..9838efdeb2cc22 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -296,38 +296,49 @@ static VectorValue emulatedVectorLoad(OpBuilder &rewriter, Location loc,
                                         newLoad);
 }
 
-/// Selects values from two sources based on a mask, and casts the result to a
-/// new type.
-static Value selectAndCast(OpBuilder &builder, Location loc,
-                           VectorType castIntoType, Value mask, Value trueValue,
-                           Value falseValue) {
-  Value maskedValue =
+/// Bitcasts both inputs to `downcastType` (if needed), selects between them
+/// element-wise using `mask`, and bitcasts the result to `upcastType`.
+static Value downcastSelectAndUpcast(OpBuilder &builder, Location loc,
+                                     VectorType downcastType,
+                                     VectorType upcastType, Value mask,
+                                     Value trueValue, Value falseValue) {
+  // Both bitcasts require the two vector types to have the same total size.
+  assert(
+      downcastType.getNumElements() * downcastType.getElementTypeBitWidth() ==
+          upcastType.getNumElements() * upcastType.getElementTypeBitWidth() &&
+      "expected downcastType and upcastType to have the same total bit width");
+  if (trueValue.getType() != downcastType)
+    trueValue = builder.create<vector::BitCastOp>(loc, downcastType, trueValue);
+  if (falseValue.getType() != downcastType)
+    falseValue =
+        builder.create<vector::BitCastOp>(loc, downcastType, falseValue);
+  Value selectedValue =
       builder.create<arith::SelectOp>(loc, mask, trueValue, falseValue);
-  return builder.create<vector::BitCastOp>(loc, castIntoType, maskedValue);
+  // Upcast the selected value to the new type.
+  return builder.create<vector::BitCastOp>(loc, upcastType, selectedValue);
 }
 
 /// Emits `memref.generic_atomic_rmw` op to store a subbyte-sized value to a
-/// byte in memory, with a mask. The `valueToStore` is a vector of subbyte-sized
-/// elements, with size of 8 bits, and the mask is used to select which elements
-/// to store.
+/// byte in `linearizedMemref`, with a mask. The `valueToStore` is a vector of
+/// subbyte-sized elements with a total size of 8 bits, and the mask selects
+/// which elements to store.
 ///
 /// Inputs:
 /// linearizedMemref = |2|2|2|2| : <4xi2> (<1xi8>)
-/// linearizedIndex = 2
+/// storeIdx = 0
 /// valueToStore = |3|3|3|3| : vector<4xi2>
 /// mask = |0|0|1|1| : vector<4xi1>
 ///
 /// Result:
 /// linearizedMemref = |2|2|3|3| : <4xi2> (<1xi8>)
 static void atomicStore(OpBuilder &builder, Location loc,
-                        MemRefValue linearizedMemref, Value linearizedIndex,
+                        MemRefValue linearizedMemref, Value storeIdx,
                         VectorValue valueToStore, Value mask) {
   assert(valueToStore.getType().getRank() == 1 && "expected 1-D vector");
 
   // Create an atomic load-modify-write region using
   // `memref.generic_atomic_rmw`.
   auto atomicOp = builder.create<memref::GenericAtomicRMWOp>(
-      loc, linearizedMemref, ValueRange{linearizedIndex});
+      loc, linearizedMemref, ValueRange{storeIdx});
   Value origValue = atomicOp.getCurrentValue();
 
   OpBuilder::InsertionGuard guard(builder);
@@ -338,30 +349,30 @@ static void atomicStore(OpBuilder &builder, Location loc,
   auto oneElemVecType = VectorType::get({1}, origValue.getType());
   Value origVecValue = builder.create<vector::FromElementsOp>(
       loc, oneElemVecType, ValueRange{origValue});
-  origVecValue = builder.create<vector::BitCastOp>(loc, valueToStore.getType(),
-                                                   origVecValue);
 
   // Construct the final masked value and yield it.
-  Value maskedValue = selectAndCast(builder, loc, oneElemVecType, mask,
-                                    valueToStore, origVecValue);
+  Value maskedValue =
+      downcastSelectAndUpcast(builder, loc, valueToStore.getType(),
+                              oneElemVecType, mask, valueToStore, origVecValue);
   auto scalarMaskedValue =
       builder.create<vector::ExtractOp>(loc, maskedValue, 0);
   builder.create<memref::AtomicYieldOp>(loc, scalarMaskedValue);
 }
 
-/// Extract `sliceNumElements` from source `vector` at `sliceOffset`,
-/// and insert it into an empty vector at offset `byteOffset`.
+/// Extract `sliceNumElements` from source `vector` at `extractOffset`,
+/// and insert it into an empty vector at `insertOffset`.
 /// Inputs:
-/// vector = |1|2|3|4| : vector<4xi2>
-/// sliceOffset = 1
+/// vector = |0|1|2|3| : vector<4xi2>
+/// extractOffset = 1
 /// sliceNumElements = 2
-/// byteOffset = 2
+/// insertOffset = 2
 /// Output:
-/// vector = |0|0|2|3| : vector<4xi2>
+/// result = |0|0|1|2| : vector<4xi2>
 static Value extractSliceIntoByte(ConversionPatternRewriter &rewriter,
                                   Location loc, VectorValue vector,
-                                  int64_t sliceOffset, int64_t sliceNumElements,
-                                  int64_t byteOffset) {
+                                  int64_t extractOffset,
+                                  int64_t sliceNumElements,
+                                  int64_t insertOffset) {
   assert(vector.getType().getRank() == 1 && "expected 1-D vector");
   auto vectorElementType = vector.getType().getElementType();
   assert(
@@ -374,9 +385,9 @@ static Value extractSliceIntoByte(ConversionPatternRewriter &rewriter,
       loc, VectorType::get({scale}, vectorElementType),
       rewriter.getZeroAttr(VectorType::get({scale}, vectorElementType)));
   auto extracted = staticallyExtractSubvector(rewriter, loc, vector,
-                                              sliceOffset, sliceNumElements);
+                                              extractOffset, sliceNumElements);
   return staticallyInsertSubvector(rewriter, loc, extracted, emptyByteVector,
-                                   byteOffset);
+                                   insertOffset);
 }
 
 namespace {
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
index b01f9165d9eb74..a80ab7b7e4166e 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
@@ -361,25 +361,27 @@ func.func @vector_maskedload_i2_constant_mask_unaligned(%passthru: vector<5xi2>)
/// vector.store
///----------------------------------------------------------------------------------------
-func.func @vector_store_i2_const_index_two_atomic(%arg0: vector<3xi2>) {
- %0 = memref.alloc() : memref<3x3xi2>
+func.func @vector_store_i2_const_index_two_atomic_rmw(%arg0: vector<3xi2>) {
+ %src = memref.alloc() : memref<3x3xi2>
%c0 = arith.constant 0 : index
%c2 = arith.constant 2 : index
- vector.store %arg0, %0[%c2, %c0] :memref<3x3xi2>, vector<3xi2>
+ vector.store %arg0, %src[%c2, %c0] :memref<3x3xi2>, vector<3xi2>
return
}
// In this example, emit 2 atomic RMWs.
-// Load from bit [12:18), byte [1:2] of total 3 bytes, both bytes needs rmw.
+//
+// Note, sizeof(%src) = 18 bits. This is modelled as %src_as_bytes:
+// <3xi8> (bits [0, 18) with the input values from %src, and [18, 24) are masked out)
-// CHECK-LABEL: func @vector_store_i2_const_index_two_atomic(
+// CHECK-LABEL: func @vector_store_i2_const_index_two_atomic_rmw(
// CHECK-SAME: %[[ARG0:.+]]: vector<3xi2>)
// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
// CHECK: %[[C1:.+]] = arith.constant 1 : index
// CHECK: %[[CST:.+]] = arith.constant dense<[false, false, true, true]> : vector<4xi1>
// CHECK: %[[CST_0:.+]] = arith.constant dense<0> : vector<4xi2>
-// Part 1 atomic RMW sequence
+// Part 1 atomic RMW sequence (load bits [12, 16) from %src_as_bytes[1])
// CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[ARG0]]
// CHECK-SAME: {offsets = [0], sizes = [2], strides = [1]} : vector<3xi2> to vector<2xi2>
// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[EXTRACT]], %[[CST_0]]
@@ -393,7 +395,7 @@ func.func @vector_store_i2_const_index_two_atomic(%arg0: vector<3xi2>) {
// CHECK: %[[EXTRACT2:.+]] = vector.extract %[[BITCAST2]][0] : i8 from vector<1xi8>
// CHECK: memref.atomic_yield %[[EXTRACT2]] : i8
-// Part 2 atomic RMW sequence
+// Part 2 atomic RMW sequence (load bits [16, 18) from %src_as_bytes[2])
// CHECK: %[[ADDR2:.+]] = arith.addi %[[C1]], %[[C1]] : index
// CHECK: %[[EXTRACT3:.+]] = vector.extract_strided_slice %[[ARG0]]
// CHECK-SAME: {offsets = [2], sizes = [1], strides = [1]} : vector<3xi2> to vector<1xi2>
@@ -411,7 +413,7 @@ func.func @vector_store_i2_const_index_two_atomic(%arg0: vector<3xi2>) {
// -----
-func.func @vector_store_i2_atomic(%arg0: vector<7xi2>) {
+func.func @vector_store_i2_atomic_rmw(%arg0: vector<7xi2>) {
%0 = memref.alloc() : memref<3x7xi2>
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
@@ -420,7 +422,7 @@ func.func @vector_store_i2_atomic(%arg0: vector<7xi2>) {
}
// In this example, emit 2 atomic RMWs and 1 non-atomic store:
-// CHECK-LABEL: func @vector_store_i2_atomic(
+// CHECK-LABEL: func @vector_store_i2_atomic_rmw(
// CHECK-SAME: %[[ARG0:.+]]: vector<7xi2>)
// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<6xi8>
// CHECK: %[[C1:.+]] = arith.constant 1 : index
@@ -467,7 +469,7 @@ func.func @vector_store_i2_atomic(%arg0: vector<7xi2>) {
// -----
-func.func @vector_store_i2_single_atomic(%arg0: vector<1xi2>) {
+func.func @vector_store_i2_const_index_one_atomic_rmw(%arg0: vector<1xi2>) {
%0 = memref.alloc() : memref<4x1xi2>
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
@@ -476,7 +478,7 @@ func.func @vector_store_i2_single_atomic(%arg0: vector<1xi2>) {
}
// In this example, only emit 1 atomic store
-// CHECK-LABEL: func @vector_store_i2_single_atomic(
+// CHECK-LABEL: func @vector_store_i2_const_index_one_atomic_rmw(
// CHECK-SAME: %[[ARG0:.+]]: vector<1xi2>)
// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<1xi8>
// CHECK: %[[C0:.+]] = arith.constant 0 : index
>From 974c2cbcfb00636d74dcb8229907ca0c546c2e83 Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Tue, 14 Jan 2025 11:05:03 +0000
Subject: [PATCH 3/8] Address comments
---
.../Transforms/VectorEmulateNarrowType.cpp | 35 +++++++++++--------
1 file changed, 21 insertions(+), 14 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 9838efdeb2cc22..2b783c1090156c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -305,12 +305,14 @@ static Value downcastSelectAndUpcast(OpBuilder &builder, Location loc,
assert(
downcastType.getNumElements() * downcastType.getElementTypeBitWidth() ==
upcastType.getNumElements() * upcastType.getElementTypeBitWidth() &&
- "expected upcastType size to be twice the size of downcastType");
- if (trueValue.getType() != downcastType)
+ "expected input and output number of bits to match");
+ if (trueValue.getType() != downcastType) {
trueValue = builder.create<vector::BitCastOp>(loc, downcastType, trueValue);
- if (falseValue.getType() != downcastType)
+ }
+ if (falseValue.getType() != downcastType) {
falseValue =
builder.create<vector::BitCastOp>(loc, downcastType, falseValue);
+ }
Value selectedType =
builder.create<arith::SelectOp>(loc, mask, trueValue, falseValue);
// Upcast the selected value to the new type.
@@ -454,28 +456,33 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
stridedMetadata.getConstifiedMixedStrides(),
getAsOpFoldResult(adaptor.getIndices()));
- auto foldedNumFrontPadElems =
+ std::optional<int64_t> foldedNumFrontPadElems =
isUnalignedEmulation
? getConstantIntValue(linearizedInfo.intraDataOffset)
: 0;
if (!foldedNumFrontPadElems) {
- // Unimplemented case for dynamic front padding size != 0
- return failure();
+ return failure("subbyte store emulation: dynamic front padding size is "
+ "not yet implemented");
}
- auto linearizedMemref = cast<MemRefValue>(adaptor.getBase());
+ auto memrefBase = cast<MemRefValue>(adaptor.getBase());
- // Shortcut: conditions when subbyte store at the front is not needed:
+ // Shortcut: conditions when subbyte emulated store at the front is not
+ // needed:
// 1. The source vector size is multiple of byte size
// 2. The address of the store is aligned to the emulated width boundary
+ //
+ // For example, to store a vector<4xi2> to <13xi2> at offset 4, does not
+ // need unaligned emulation because the store address is aligned and the
+ // source is a whole byte.
if (!isUnalignedEmulation && *foldedNumFrontPadElems == 0) {
auto numElements = origElements / numSrcElemsPerDest;
auto bitCast = rewriter.create<vector::BitCastOp>(
loc, VectorType::get(numElements, newElementType),
op.getValueToStore());
rewriter.replaceOpWithNewOp<vector::StoreOp>(
- op, bitCast.getResult(), linearizedMemref,
+ op, bitCast.getResult(), memrefBase,
getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
return success();
}
@@ -511,7 +518,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
extractSliceIntoByte(rewriter, loc, valueToStore, 0,
frontSubWidthStoreElem, *foldedNumFrontPadElems);
- atomicStore(rewriter, loc, linearizedMemref, currentDestIndex,
+ atomicStore(rewriter, loc, memrefBase, currentDestIndex,
cast<VectorValue>(value), frontMask.getResult());
}
@@ -537,13 +544,13 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
numNonFullWidthElements);
auto originType = cast<VectorType>(fullWidthStorePart.getType());
- auto memrefElemType = getElementTypeOrSelf(linearizedMemref.getType());
+ auto memrefElemType = getElementTypeOrSelf(memrefBase.getType());
auto storeType = VectorType::get(
{originType.getNumElements() / numSrcElemsPerDest}, memrefElemType);
auto bitCast = rewriter.create<vector::BitCastOp>(loc, storeType,
fullWidthStorePart);
- rewriter.create<vector::StoreOp>(loc, bitCast.getResult(),
- linearizedMemref, currentDestIndex);
+ rewriter.create<vector::StoreOp>(loc, bitCast.getResult(), memrefBase,
+ currentDestIndex);
currentSourceIndex += numNonFullWidthElements;
currentDestIndex = rewriter.create<arith::AddIOp>(
@@ -565,7 +572,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
auto backMask = rewriter.create<arith::ConstantOp>(
loc, DenseElementsAttr::get(subWidthStoreMaskType, maskValues));
- atomicStore(rewriter, loc, linearizedMemref, currentDestIndex,
+ atomicStore(rewriter, loc, memrefBase, currentDestIndex,
cast<VectorValue>(subWidthStorePart), backMask.getResult());
}
>From cfe16dbcf32a1b4ebb2b50619f65ed36f8611a14 Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Tue, 14 Jan 2025 11:11:55 +0000
Subject: [PATCH 4/8] Another update
---
mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp | 2 ++
1 file changed, 2 insertions(+)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 2b783c1090156c..e89caea03a8aec 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -377,6 +377,8 @@ static Value extractSliceIntoByte(ConversionPatternRewriter &rewriter,
int64_t insertOffset) {
assert(vector.getType().getRank() == 1 && "expected 1-D vector");
auto vectorElementType = vector.getType().getElementType();
+ // TODO: update and use `alignedConversionPrecondition` in the place of
+ // these asserts.
assert(
sliceNumElements * vectorElementType.getIntOrFloatBitWidth() <= 8 &&
"sliceNumElements * vector element size must be less than or equal to 8");
>From 149dedad08e0eec46feaf424d3dd76f455b680d4 Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Wed, 15 Jan 2025 07:05:13 +0000
Subject: [PATCH 5/8] update again to address comments
---
.../Transforms/VectorEmulateNarrowType.cpp | 108 ++++++++++++------
.../vector-emulate-narrow-type-unaligned.mlir | 12 +-
2 files changed, 79 insertions(+), 41 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index e89caea03a8aec..d24ee95e75ca7f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -400,6 +400,9 @@ namespace {
// ConvertVectorStore
//===----------------------------------------------------------------------===//
+///
+///
+
struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
using OpConversionPattern::OpConversionPattern;
@@ -443,7 +446,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
// vector<4xi8>
auto origElements = valueToStore.getType().getNumElements();
- bool isUnalignedEmulation = origElements % numSrcElemsPerDest != 0;
+ bool isAlignedEmulation = origElements % numSrcElemsPerDest == 0;
auto stridedMetadata =
rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
@@ -459,9 +462,9 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
getAsOpFoldResult(adaptor.getIndices()));
std::optional<int64_t> foldedNumFrontPadElems =
- isUnalignedEmulation
- ? getConstantIntValue(linearizedInfo.intraDataOffset)
- : 0;
+ isAlignedEmulation
+ ? 0
+ : getConstantIntValue(linearizedInfo.intraDataOffset);
if (!foldedNumFrontPadElems) {
return failure("subbyte store emulation: dynamic front padding size is "
@@ -472,13 +475,13 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
// Shortcut: conditions when subbyte emulated store at the front is not
// needed:
- // 1. The source vector size is multiple of byte size
- // 2. The address of the store is aligned to the emulated width boundary
+ // 1. The source vector size (in bits) is a multiple of byte size.
+ // 2. The address of the store is aligned to the emulated width boundary.
//
// For example, to store a vector<4xi2> to <13xi2> at offset 4, does not
// need unaligned emulation because the store address is aligned and the
// source is a whole byte.
- if (!isUnalignedEmulation && *foldedNumFrontPadElems == 0) {
+ if (isAlignedEmulation && *foldedNumFrontPadElems == 0) {
auto numElements = origElements / numSrcElemsPerDest;
auto bitCast = rewriter.create<vector::BitCastOp>(
loc, VectorType::get(numElements, newElementType),
@@ -489,17 +492,50 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
return success();
}
- // The index into the target memref we are storing to
+ // Next, handle the case when sub-byte read-modify-write
+ // sequences are needed to emulate a vector store.
+ // Here is an example:
+ //
+ // Vector to store: vector<7xi2>
+ // Value to store: 11 11 11 11 11 11 11 (all ones)
+ //
+ // Destination: memref<12xi2>
+ // Store offset: 2 (i.e. 4 bits into the 1st emulated byte).
+ //
+ // MLIR: vector.store %val, %dest[%c2] : memref<12xi2>, vector<7xi2>
+ //
+ // Destination memref before:
+ //
+ // Byte 0 Byte 1 Byte 2
+ // +----------+----------+----------+
+ // | 00000000 | 00000000 | 00000000 |
+ // +----------+----------+----------+
+ //
+ // Destination memref after:
+ //
+ // Byte 0 Byte 1 Byte 2
+ // +----------+----------+----------+
+ // | 00001111 | 11111111 | 11000000 |
+ // +----------+----------+----------+
+ //
+ // Note, stores to Byte 1 are "full-width" and hence don't require RMW (no
+ // need for atomicity). Stores to Bytes 0 and 2 are "partial", hence
+ // requiring RMW access (atomicity is required).
+
+ // The index into the target memref we are storing to.
Value currentDestIndex =
getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices);
+ // The index into the source vector we are currently processing.
+ auto currentSourceIndex = 0;
+
+ // Build a mask used for rmw.
auto subWidthStoreMaskType =
VectorType::get({numSrcElemsPerDest}, rewriter.getI1Type());
- // The index into the source vector we are currently processing
- auto currentSourceIndex = 0;
- // 1. Partial width store for the first byte, when the store address is not
- // aligned to emulated width boundary, deal with the unaligned part so that
- // the rest elements are aligned to width boundary.
+ // 1. Partial width store for the leading byte.
+ // When the store address is not aligned to the emulated width boundary,
+ // deal with the unaligned part so that the remaining elements are aligned
+ // to the width boundary.
auto frontSubWidthStoreElem =
(numSrcElemsPerDest - *foldedNumFrontPadElems) % numSrcElemsPerDest;
if (frontSubWidthStoreElem > 0) {
@@ -535,8 +571,9 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
currentDestIndex = rewriter.create<arith::AddIOp>(
loc, rewriter.getIndexType(), currentDestIndex, constantOne);
- // 2. Full width store. After the previous step, the store address is
- // aligned to the emulated width boundary.
+ // 2. Full width store for the inner output bytes.
+ // After the previous step, the store address is aligned to the emulated
+ // width boundary.
int64_t fullWidthStoreSize =
(origElements - currentSourceIndex) / numSrcElemsPerDest;
int64_t numNonFullWidthElements = fullWidthStoreSize * numSrcElemsPerDest;
@@ -560,15 +597,16 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
rewriter.create<arith::ConstantIndexOp>(loc, fullWidthStoreSize));
}
- // 3. Deal with trailing elements that are aligned to the emulated width,
- // but their length is smaller than the emulated width.
+ // 3. Partial width store for the trailing output byte.
+ // It is needed when the residual length is smaller than the emulated width,
+ // which is not covered in step 2 above.
auto remainingElements = origElements - currentSourceIndex;
if (remainingElements != 0) {
auto subWidthStorePart =
extractSliceIntoByte(rewriter, loc, cast<VectorValue>(valueToStore),
currentSourceIndex, remainingElements, 0);
- // Generate back mask
+ // Generate back mask.
auto maskValues = SmallVector<bool>(numSrcElemsPerDest, 0);
std::fill_n(maskValues.begin(), remainingElements, 1);
auto backMask = rewriter.create<arith::ConstantOp>(
@@ -751,7 +789,7 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
// compile time as they must be constants.
auto origElements = op.getVectorType().getNumElements();
- bool isUnalignedEmulation = origElements % scale != 0;
+ bool isAlignedEmulation = origElements % scale == 0;
auto stridedMetadata =
rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
@@ -767,9 +805,9 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
getAsOpFoldResult(adaptor.getIndices()));
std::optional<int64_t> foldedIntraVectorOffset =
- isUnalignedEmulation
- ? getConstantIntValue(linearizedInfo.intraDataOffset)
- : 0;
+ isAlignedEmulation
+ ? 0
+ : getConstantIntValue(linearizedInfo.intraDataOffset);
// Always load enough elements which can cover the original elements.
int64_t maxintraDataOffset = foldedIntraVectorOffset.value_or(scale - 1);
@@ -785,7 +823,7 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
result = dynamicallyExtractSubVector(
rewriter, loc, cast<VectorValue>(result), resultVector,
linearizedInfo.intraDataOffset, origElements);
- } else if (isUnalignedEmulation) {
+ } else if (!isAlignedEmulation) {
result = staticallyExtractSubvector(
rewriter, loc, result, *foldedIntraVectorOffset, origElements);
}
@@ -867,7 +905,7 @@ struct ConvertVectorMaskedLoad final
// subvector at the proper offset after bit-casting.
auto origType = op.getVectorType();
auto origElements = origType.getNumElements();
- bool isUnalignedEmulation = origElements % scale != 0;
+ bool isAlignedEmulation = origElements % scale == 0;
auto stridedMetadata =
rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
@@ -882,9 +920,9 @@ struct ConvertVectorMaskedLoad final
getAsOpFoldResult(adaptor.getIndices()));
std::optional<int64_t> foldedIntraVectorOffset =
- isUnalignedEmulation
- ? getConstantIntValue(linearizedInfo.intraDataOffset)
- : 0;
+ isAlignedEmulation
+ ? 0
+ : getConstantIntValue(linearizedInfo.intraDataOffset);
int64_t maxIntraDataOffset = foldedIntraVectorOffset.value_or(scale - 1);
FailureOr<Operation *> newMask = getCompressedMaskOp(
@@ -905,7 +943,7 @@ struct ConvertVectorMaskedLoad final
passthru = dynamicallyInsertSubVector(
rewriter, loc, cast<VectorValue>(passthru), emptyVector,
linearizedInfo.intraDataOffset, origElements);
- } else if (isUnalignedEmulation) {
+ } else if (!isAlignedEmulation) {
passthru = staticallyInsertSubvector(rewriter, loc, passthru, emptyVector,
*foldedIntraVectorOffset);
}
@@ -933,7 +971,7 @@ struct ConvertVectorMaskedLoad final
mask = dynamicallyInsertSubVector(
rewriter, loc, cast<VectorValue>(mask), emptyMask,
linearizedInfo.intraDataOffset, origElements);
- } else if (isUnalignedEmulation) {
+ } else if (!isAlignedEmulation) {
mask = staticallyInsertSubvector(rewriter, loc, op.getMask(), emptyMask,
*foldedIntraVectorOffset);
}
@@ -944,7 +982,7 @@ struct ConvertVectorMaskedLoad final
result = dynamicallyExtractSubVector(
rewriter, loc, cast<VectorValue>(result), op.getPassThru(),
linearizedInfo.intraDataOffset, origElements);
- } else if (isUnalignedEmulation) {
+ } else if (!isAlignedEmulation) {
result = staticallyExtractSubvector(
rewriter, loc, result, *foldedIntraVectorOffset, origElements);
}
@@ -986,7 +1024,7 @@ struct ConvertVectorTransferRead final
auto origElements = op.getVectorType().getNumElements();
- bool isUnalignedEmulation = origElements % scale != 0;
+ bool isAlignedEmulation = origElements % scale == 0;
auto newPadding = rewriter.create<arith::ExtUIOp>(loc, newElementType,
adaptor.getPadding());
@@ -1005,9 +1043,9 @@ struct ConvertVectorTransferRead final
getAsOpFoldResult(adaptor.getIndices()));
std::optional<int64_t> foldedIntraVectorOffset =
- isUnalignedEmulation
- ? getConstantIntValue(linearizedInfo.intraDataOffset)
- : 0;
+ isAlignedEmulation
+ ? 0
+ : getConstantIntValue(linearizedInfo.intraDataOffset);
int64_t maxIntraDataOffset = foldedIntraVectorOffset.value_or(scale - 1);
auto numElements =
@@ -1028,7 +1066,7 @@ struct ConvertVectorTransferRead final
result = dynamicallyExtractSubVector(rewriter, loc, bitCast, zeros,
linearizedInfo.intraDataOffset,
origElements);
- } else if (isUnalignedEmulation) {
+ } else if (!isAlignedEmulation) {
result = staticallyExtractSubvector(
rewriter, loc, result, *foldedIntraVectorOffset, origElements);
}
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
index a80ab7b7e4166e..68d2bd99c201bb 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
@@ -361,7 +361,7 @@ func.func @vector_maskedload_i2_constant_mask_unaligned(%passthru: vector<5xi2>)
/// vector.store
///----------------------------------------------------------------------------------------
-func.func @vector_store_i2_const_index_two_atomic_rmw(%arg0: vector<3xi2>) {
+func.func @vector_store_i2_const_index_two_partial_stores(%arg0: vector<3xi2>) {
%src = memref.alloc() : memref<3x3xi2>
%c0 = arith.constant 0 : index
%c2 = arith.constant 2 : index
@@ -374,7 +374,7 @@ func.func @vector_store_i2_const_index_two_atomic_rmw(%arg0: vector<3xi2>) {
// Note, sizeof(%src) = 18 bits. This is modelled as %src_as_bytes:
// <3xi8> (bits [0, 18) with the input values from %src, and [18, 24) are masked out)
-// CHECK-LABEL: func @vector_store_i2_const_index_two_atomic_rmw(
+// CHECK-LABEL: func @vector_store_i2_const_index_two_partial_stores(
// CHECK-SAME: %[[ARG0:.+]]: vector<3xi2>)
// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
// CHECK: %[[C1:.+]] = arith.constant 1 : index
@@ -413,7 +413,7 @@ func.func @vector_store_i2_const_index_two_atomic_rmw(%arg0: vector<3xi2>) {
// -----
-func.func @vector_store_i2_atomic_rmw(%arg0: vector<7xi2>) {
+func.func @vector_store_i2_two_partial_one_full_stores(%arg0: vector<7xi2>) {
%0 = memref.alloc() : memref<3x7xi2>
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
@@ -422,7 +422,7 @@ func.func @vector_store_i2_atomic_rmw(%arg0: vector<7xi2>) {
}
// In this example, emit 2 atomic RMWs and 1 non-atomic store:
-// CHECK-LABEL: func @vector_store_i2_atomic_rmw(
+// CHECK-LABEL: func @vector_store_i2_two_partial_one_full_stores(
// CHECK-SAME: %[[ARG0:.+]]: vector<7xi2>)
// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<6xi8>
// CHECK: %[[C1:.+]] = arith.constant 1 : index
@@ -469,7 +469,7 @@ func.func @vector_store_i2_atomic_rmw(%arg0: vector<7xi2>) {
// -----
-func.func @vector_store_i2_const_index_one_atomic_rmw(%arg0: vector<1xi2>) {
+func.func @vector_store_i2_const_index_one_partial_store(%arg0: vector<1xi2>) {
%0 = memref.alloc() : memref<4x1xi2>
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
@@ -478,7 +478,7 @@ func.func @vector_store_i2_const_index_one_atomic_rmw(%arg0: vector<1xi2>) {
}
// In this example, only emit 1 atomic store
-// CHECK-LABEL: func @vector_store_i2_const_index_one_atomic_rmw(
+// CHECK-LABEL: func @vector_store_i2_const_index_one_partial_store(
// CHECK-SAME: %[[ARG0:.+]]: vector<1xi2>)
// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<1xi8>
// CHECK: %[[C0:.+]] = arith.constant 0 : index
>From a6ec095b18f5ba08141d599fc0c579fbf3732211 Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Thu, 16 Jan 2025 16:46:14 +0000
Subject: [PATCH 6/8] final touch
---
mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp | 3 ---
1 file changed, 3 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index d24ee95e75ca7f..d0deae2c476df3 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -400,9 +400,6 @@ namespace {
// ConvertVectorStore
//===----------------------------------------------------------------------===//
-///
-///
-
struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
using OpConversionPattern::OpConversionPattern;
>From 50f0786d1e727fea5ade9d7ea1025f74681cc353 Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Fri, 17 Jan 2025 04:58:42 +0000
Subject: [PATCH 7/8] updates about comments
---
.../Transforms/VectorEmulateNarrowType.cpp | 20 +++++++++----------
1 file changed, 10 insertions(+), 10 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index d0deae2c476df3..0225650027b685 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -403,9 +403,6 @@ namespace {
struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
using OpConversionPattern::OpConversionPattern;
- ConvertVectorStore(MLIRContext *context, bool useAtomicWrites)
- : OpConversionPattern<vector::StoreOp>(context) {}
-
LogicalResult
matchAndRewrite(vector::StoreOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
@@ -416,10 +413,10 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
"only 1-D vectors are supported ATM");
auto loc = op.getLoc();
- auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
auto valueToStore = cast<VectorValue>(op.getValueToStore());
auto oldElementType = valueToStore.getType().getElementType();
- auto newElementType = convertedType.getElementType();
+ auto newElementType =
+ cast<MemRefType>(adaptor.getBase().getType()).getElementType();
int srcBits = oldElementType.getIntOrFloatBitWidth();
int dstBits = newElementType.getIntOrFloatBitWidth();
@@ -464,21 +461,24 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
: getConstantIntValue(linearizedInfo.intraDataOffset);
if (!foldedNumFrontPadElems) {
- return failure("subbyte store emulation: dynamic front padding size is "
- "not yet implemented");
+ return rewriter.notifyMatchFailure(
+ op, "subbyte store emulation: dynamic front padding size is "
+ "not yet implemented");
}
auto memrefBase = cast<MemRefValue>(adaptor.getBase());
- // Shortcut: conditions when subbyte emulated store at the front is not
- // needed:
+ // Conditions when subbyte emulated store is not needed:
// 1. The source vector size (in bits) is a multiple of byte size.
// 2. The address of the store is aligned to the emulated width boundary.
//
// For example, to store a vector<4xi2> to <13xi2> at offset 4, does not
// need unaligned emulation because the store address is aligned and the
// source is a whole byte.
- if (isAlignedEmulation && *foldedNumFrontPadElems == 0) {
+ bool emulationRequiresPartialStores =
+ !isAlignedEmulation || *foldedNumFrontPadElems != 0;
+ if (!emulationRequiresPartialStores) {
+ // Basic case: storing full bytes.
auto numElements = origElements / numSrcElemsPerDest;
auto bitCast = rewriter.create<vector::BitCastOp>(
loc, VectorType::get(numElements, newElementType),
>From d458e3b3ac97766c01fe68a7184503c1801bd044 Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Fri, 24 Jan 2025 11:12:43 +0000
Subject: [PATCH 8/8] another update according to comments
---
.../Transforms/VectorEmulateNarrowType.cpp | 28 ++++++++++---------
1 file changed, 15 insertions(+), 13 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 0225650027b685..7ca88f1e0a0df9 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -240,9 +240,10 @@ static Value staticallyInsertSubvector(OpBuilder &rewriter, Location loc,
/// function emits multiple `vector.extract` and `vector.insert` ops, so only
/// use it when `offset` cannot be folded into a constant value.
static Value dynamicallyExtractSubVector(OpBuilder &rewriter, Location loc,
- VectorValue source, Value dest,
+ Value source, Value dest,
OpFoldResult offset,
int64_t numElementsToExtract) {
+ assert(isa<VectorValue>(source) && "expected `source` to be a vector type");
for (int i = 0; i < numElementsToExtract; ++i) {
Value extractLoc =
(i == 0) ? offset.dyn_cast<Value>()
@@ -258,9 +259,10 @@ static Value dynamicallyExtractSubVector(OpBuilder &rewriter, Location loc,
/// Inserts a 1-D subvector into a 1-D `dest` vector at index `destOffsetVar`.
static Value dynamicallyInsertSubVector(RewriterBase &rewriter, Location loc,
- VectorValue source, Value dest,
+ Value source, Value dest,
OpFoldResult destOffsetVar,
size_t length) {
+ assert(isa<VectorValue>(source) && "expected `source` to be a vector type");
assert(length > 0 && "length must be greater than 0");
Value destOffsetVal =
getValueOrCreateConstantIndexOp(rewriter, loc, destOffsetVar);
@@ -468,7 +470,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
auto memrefBase = cast<MemRefValue>(adaptor.getBase());
- // Conditions when subbyte emulated store is not needed:
+ // Conditions when atomic RMWs are not needed:
// 1. The source vector size (in bits) is a multiple of byte size.
// 2. The address of the store is aligned to the emulated width boundary.
//
@@ -499,7 +501,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
// Destination: memref<12xi2>
// Store offset: 2 (i.e. 4 bits into the 1st emulated byte).
//
- // MLIR: vector.store %val, %dest[%c2] : memref<12xi2>, vector<7xi2>
+ // Input MLIR: vector.store %val, %dest[%c2] : memref<12xi2>, vector<7xi2>
//
// Destination memref before:
//
@@ -817,9 +819,9 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
if (!foldedIntraVectorOffset) {
auto resultVector = rewriter.create<arith::ConstantOp>(
loc, op.getType(), rewriter.getZeroAttr(op.getType()));
- result = dynamicallyExtractSubVector(
- rewriter, loc, cast<VectorValue>(result), resultVector,
- linearizedInfo.intraDataOffset, origElements);
+ result = dynamicallyExtractSubVector(rewriter, loc, result, resultVector,
+ linearizedInfo.intraDataOffset,
+ origElements);
} else if (!isAlignedEmulation) {
result = staticallyExtractSubvector(
rewriter, loc, result, *foldedIntraVectorOffset, origElements);
@@ -938,8 +940,8 @@ struct ConvertVectorMaskedLoad final
loc, newBitcastType, rewriter.getZeroAttr(newBitcastType));
if (!foldedIntraVectorOffset) {
passthru = dynamicallyInsertSubVector(
- rewriter, loc, cast<VectorValue>(passthru), emptyVector,
- linearizedInfo.intraDataOffset, origElements);
+ rewriter, loc, passthru, emptyVector, linearizedInfo.intraDataOffset,
+ origElements);
} else if (!isAlignedEmulation) {
passthru = staticallyInsertSubvector(rewriter, loc, passthru, emptyVector,
*foldedIntraVectorOffset);
@@ -965,9 +967,9 @@ struct ConvertVectorMaskedLoad final
auto emptyMask = rewriter.create<arith::ConstantOp>(
loc, newSelectMaskType, rewriter.getZeroAttr(newSelectMaskType));
if (!foldedIntraVectorOffset) {
- mask = dynamicallyInsertSubVector(
- rewriter, loc, cast<VectorValue>(mask), emptyMask,
- linearizedInfo.intraDataOffset, origElements);
+ mask = dynamicallyInsertSubVector(rewriter, loc, mask, emptyMask,
+ linearizedInfo.intraDataOffset,
+ origElements);
} else if (!isAlignedEmulation) {
mask = staticallyInsertSubvector(rewriter, loc, op.getMask(), emptyMask,
*foldedIntraVectorOffset);
@@ -977,7 +979,7 @@ struct ConvertVectorMaskedLoad final
rewriter.create<arith::SelectOp>(loc, mask, bitCast, passthru);
if (!foldedIntraVectorOffset) {
result = dynamicallyExtractSubVector(
- rewriter, loc, cast<VectorValue>(result), op.getPassThru(),
+ rewriter, loc, result, op.getPassThru(),
linearizedInfo.intraDataOffset, origElements);
} else if (!isAlignedEmulation) {
result = staticallyExtractSubvector(
More information about the Mlir-commits
mailing list