[Mlir-commits] [mlir] [MLIR] Implement emulation of static indexing subbyte type vector stores (PR #115922)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Jan 7 07:46:07 PST 2025
https://github.com/lialan updated https://github.com/llvm/llvm-project/pull/115922
>From 84d977bd91231b84c09b54e1052de4ba21bac57d Mon Sep 17 00:00:00 2001
From: Ubuntu <450283+lialan at users.noreply.github.com>
Date: Fri, 25 Oct 2024 15:19:42 +0000
Subject: [PATCH] [MLIR] Implement emulation of static indexing subbyte type
vector stores
This patch enables unaligned, statically indexed stores of vectors whose element type is narrower than the emulation width. By default, it ensures atomicity by using an atomic read-modify-write (RMW) sequence wherever data contention might occur. The caller can avoid emitting atomic operations by setting `useAtomicWrites` to false when calling `populateVectorNarrowTypeEmulationPatterns`. In that case, a regular RMW sequence is emitted for efficiency.
To illustrate the mechanism, consider the example of storing `vector<7xi2>` into `memref<3x7xi2>[1, 0]`. In this case the linearized indices of those bits being overwritten are `[14, 28)`, which are:
* the last 2 bits of byte no.2
* byte no.3
* first 4 bits of byte no.4
Because memory is accessed in whole bytes, bytes no.2 and no.4 in the above example are only partially modified. In a multi-threaded scenario, these two bytes must be handled atomically to avoid data contention. However, if the caller of the pass knows that this is not a concern, the atomic operations are replaced with a plain (non-atomic) RMW sequence.
---
.../Vector/Transforms/VectorRewritePatterns.h | 6 +-
.../Transforms/VectorEmulateNarrowType.cpp | 335 +++++++++++++++---
...late-narrow-type-unaligned-non-atomic.mlir | 119 +++++++
.../vector-emulate-narrow-type-unaligned.mlir | 137 +++++++
.../Dialect/MemRef/TestEmulateNarrowType.cpp | 8 +-
5 files changed, 554 insertions(+), 51 deletions(-)
create mode 100644 mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned-non-atomic.mlir
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index a59f06f3c1ef1b..43478aacb50a14 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -364,10 +364,12 @@ void populateVectorMaskMaterializationPatterns(RewritePatternSet &patterns,
PatternBenefit benefit = 1);
/// Appends patterns for emulating vector operations over narrow types with ops
-/// over wider types.
+/// over wider types. If `useAtomicWrites` is true (the default), sub-byte
+/// stores are emitted as `memref.generic_atomic_rmw` ops; if false, a plain
+/// (non-atomic) read-modify-write sequence is emitted instead.
void populateVectorNarrowTypeEmulationPatterns(
const arith::NarrowTypeEmulationConverter &typeConverter,
- RewritePatternSet &patterns);
+ RewritePatternSet &patterns, bool useAtomicWrites = true);
/// Rewrite a vector `bitcast(trunci)` to use a more efficient sequence of
/// vector operations comprising `shuffle` and `bitwise` ops.
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 181c394edc1d20..919f84210f343d 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -33,6 +33,7 @@
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
#include <cstdint>
@@ -45,6 +46,9 @@ using namespace mlir;
#define DBGSNL() (llvm::dbgs() << "\n")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
+using VectorValue = TypedValue<VectorType>;
+using MemRefValue = TypedValue<MemRefType>;
+
/// Returns a compressed mask for the emulated vector. For example, when
/// emulating an eight-element `i8` vector with `i32` (i.e. when the source
/// elements span two dest elements), this method compresses `vector<8xi1>`
@@ -194,13 +198,10 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
/// Extracts 1-D subvector from a 1-D vector. It is a wrapper function for
/// emitting `vector.extract_strided_slice`.
static Value staticallyExtractSubvector(OpBuilder &rewriter, Location loc,
- VectorType extractType, Value source,
- int64_t frontOffset,
+ Value source, int64_t frontOffset,
int64_t subvecSize) {
auto vectorType = cast<VectorType>(source.getType());
- assert((vectorType.getRank() == 1 && extractType.getRank() == 1) &&
- "expected 1-D source and destination types");
- (void)vectorType;
+ assert(vectorType.getRank() == 1 && "expected 1-D source types");
assert(frontOffset + subvecSize <= vectorType.getNumElements() &&
"subvector out of bounds");
@@ -211,9 +212,12 @@ static Value staticallyExtractSubvector(OpBuilder &rewriter, Location loc,
auto offsets = rewriter.getI64ArrayAttr({frontOffset});
auto sizes = rewriter.getI64ArrayAttr({subvecSize});
auto strides = rewriter.getI64ArrayAttr({1});
+
+ auto resultVectorType =
+ VectorType::get({subvecSize}, vectorType.getElementType());
return rewriter
- .create<vector::ExtractStridedSliceOp>(loc, extractType, source, offsets,
- sizes, strides)
+ .create<vector::ExtractStridedSliceOp>(loc, resultVectorType, source,
+ offsets, sizes, strides)
->getResult(0);
}
@@ -237,8 +241,8 @@ static Value staticallyInsertSubvector(OpBuilder &rewriter, Location loc,
/// function emits multiple `vector.extract` and `vector.insert` ops, so only
/// use it when `offset` cannot be folded into a constant value.
static Value dynamicallyExtractSubVector(OpBuilder &rewriter, Location loc,
- TypedValue<VectorType> source,
- Value dest, OpFoldResult offset,
+ VectorValue source, Value dest,
+ OpFoldResult offset,
int64_t numElementsToExtract) {
for (int i = 0; i < numElementsToExtract; ++i) {
Value extractLoc =
@@ -255,8 +259,8 @@ static Value dynamicallyExtractSubVector(OpBuilder &rewriter, Location loc,
/// Inserts a 1-D subvector into a 1-D `dest` vector at index `destOffsetVar`.
static Value dynamicallyInsertSubVector(RewriterBase &rewriter, Location loc,
- TypedValue<VectorType> source,
- Value dest, OpFoldResult destOffsetVar,
+ VectorValue source, Value dest,
+ OpFoldResult destOffsetVar,
size_t length) {
assert(length > 0 && "length must be greater than 0");
Value destOffsetVal =
@@ -277,11 +281,12 @@ static Value dynamicallyInsertSubVector(RewriterBase &rewriter, Location loc,
/// specifically, use `emulatedElemType` for loading a vector of `origElemType`.
/// The load location is given by `base` and `linearizedIndices`, and the
/// load size is given by `numEmulatedElementsToLoad`.
-static TypedValue<VectorType>
-emulatedVectorLoad(OpBuilder &rewriter, Location loc, Value base,
- OpFoldResult linearizedIndices,
- int64_t numEmultedElementsToLoad, Type origElemType,
- Type emulatedElemType) {
+static VectorValue emulatedVectorLoad(OpBuilder &rewriter, Location loc,
+ Value base,
+ OpFoldResult linearizedIndices,
+ int64_t numEmultedElementsToLoad,
+ Type origElemType,
+ Type emulatedElemType) {
auto scale = emulatedElemType.getIntOrFloatBitWidth() /
origElemType.getIntOrFloatBitWidth();
auto newLoad = rewriter.create<vector::LoadOp>(
@@ -292,6 +297,106 @@ emulatedVectorLoad(OpBuilder &rewriter, Location loc, Value base,
newLoad);
}
+/// Emits a `memref.generic_atomic_rmw` op that merges `valueToStore` into the
+/// element of `linearizedMemref` at `linearizedIndex`. `valueToStore` is a 1-D
+/// vector of sub-byte elements that is bitcast to and from a one-element
+/// vector of the memref element type; `mask` selects which of its lanes
+/// replace the value currently in memory.
+///
+/// Example (showing only the addressed memref element):
+///   element in memory = |2|2|2|2| : i8 viewed as vector<4xi2>
+///   valueToStore      = |3|3|3|3| : vector<4xi2>
+///   mask              = |0|0|1|1| : vector<4xi1>
+///   element afterward = |2|2|3|3|
+static void atomicStore(OpBuilder &builder, Location loc,
+                        MemRefValue linearizedMemref, Value linearizedIndex,
+                        VectorValue valueToStore, Value mask) {
+  VectorType subByteVecType = valueToStore.getType();
+  assert(subByteVecType.getRank() == 1 && "expected 1-D vector");
+
+  // The region of the RMW op receives the current memory value as its block
+  // argument and must yield the value to write back.
+  auto rmwOp = builder.create<memref::GenericAtomicRMWOp>(
+      loc, linearizedMemref, ValueRange{linearizedIndex});
+  Value currentElem = rmwOp.getCurrentValue();
+  auto containerVecType = VectorType::get({1}, currentElem.getType());
+
+  // Everything below is built inside the RMW region; the guard restores the
+  // caller's insertion point on exit.
+  OpBuilder::InsertionGuard guard(builder);
+  builder.setInsertionPointToStart(rmwOp.getBody());
+
+  // Reinterpret the current memory value as a vector of sub-byte elements.
+  Value currentAsVec = builder.create<vector::FromElementsOp>(
+      loc, containerVecType, ValueRange{currentElem});
+  Value currentLanes =
+      builder.create<vector::BitCastOp>(loc, subByteVecType, currentAsVec);
+
+  // Take the new lanes where the mask is set and keep the old ones elsewhere,
+  // then convert back to the scalar memref element type and yield it.
+  Value mergedLanes =
+      builder.create<arith::SelectOp>(loc, mask, valueToStore, currentLanes);
+  Value mergedAsVec =
+      builder.create<vector::BitCastOp>(loc, containerVecType, mergedLanes);
+  auto mergedElem = builder.create<vector::ExtractOp>(loc, mergedAsVec, 0);
+  builder.create<memref::AtomicYieldOp>(loc, mergedElem);
+}
+
+/// Generates a non-atomic read-modify-write sequence for a sub-byte store.
+/// It mirrors `atomicStore`: the element of `linearizedMemref` at
+/// `linearizedIndex` is loaded, the lanes of `valueToStore` whose `mask` bit is
+/// set replace the corresponding loaded lanes, and the merged element is stored
+/// back. `valueToStore` must be a 1-D vector of sub-byte elements whose total
+/// width matches one memref element.
+static void rmwStore(OpBuilder &builder, Location loc,
+                     MemRefValue linearizedMemref, Value linearizedIndex,
+                     VectorValue valueToStore, Value mask) {
+  assert(valueToStore.getType().getRank() == 1 && "expected 1-D vector");
+
+  // Load the current memory element as a one-element vector and reinterpret
+  // it as a vector of sub-byte elements.
+  auto oneElemVecType =
+      VectorType::get({1}, linearizedMemref.getType().getElementType());
+  Value origVecValue = builder.create<vector::LoadOp>(
+      loc, oneElemVecType, linearizedMemref, ValueRange{linearizedIndex});
+  origVecValue = builder.create<vector::BitCastOp>(loc, valueToStore.getType(),
+                                                   origVecValue);
+
+  // Merge and write back. `arith.select` yields its first value operand where
+  // the mask is set, so the new value must come first and the loaded value
+  // second (same order as in `atomicStore`); the reverse order would keep the
+  // old lanes that were meant to be overwritten and clobber the others.
+  Value maskedValue =
+      builder.create<arith::SelectOp>(loc, mask, valueToStore, origVecValue);
+  maskedValue =
+      builder.create<vector::BitCastOp>(loc, oneElemVecType, maskedValue);
+  builder.create<vector::StoreOp>(loc, maskedValue, linearizedMemref,
+                                  linearizedIndex);
+}
+
+/// Takes `sliceNumElements` elements of the 1-D `vector` starting at
+/// `sliceOffset` and places them at `byteOffset` inside an otherwise
+/// zero-filled vector that is exactly one byte (8 bits) wide.
+/// Inputs:
+///   vector = |1|2|3|4| : vector<4xi2>
+///   sliceOffset = 1
+///   sliceNumElements = 2
+///   byteOffset = 2
+/// Output:
+///   vector = |0|0|2|3| : vector<4xi2>
+static Value extractSliceIntoByte(ConversionPatternRewriter &rewriter,
+                                  Location loc, VectorValue vector,
+                                  int64_t sliceOffset, int64_t sliceNumElements,
+                                  int64_t byteOffset) {
+  VectorType sourceType = vector.getType();
+  assert(sourceType.getRank() == 1 && "expected 1-D vector");
+  Type elemType = sourceType.getElementType();
+  auto elemBitWidth = elemType.getIntOrFloatBitWidth();
+  assert(
+      sliceNumElements * elemBitWidth <= 8 &&
+      "sliceNumElements * vector element size must be less than or equal to 8");
+  assert(8 % elemBitWidth == 0 &&
+         "vector element must be a valid sub-byte type");
+
+  // Zero-initialized destination holding as many elements as fit in a byte.
+  auto elemsPerByte = 8 / elemBitWidth;
+  auto byteVecType = VectorType::get({elemsPerByte}, elemType);
+  auto zeroByteVector = rewriter.create<arith::ConstantOp>(
+      loc, byteVecType, rewriter.getZeroAttr(byteVecType));
+
+  // Cut the slice out of the source and drop it into the zero vector.
+  auto slice = staticallyExtractSubvector(rewriter, loc, vector, sliceOffset,
+                                          sliceNumElements);
+  return staticallyInsertSubvector(rewriter, loc, slice, zeroByteVector,
+                                   byteOffset);
+}
+
namespace {
//===----------------------------------------------------------------------===//
@@ -301,6 +406,10 @@ namespace {
struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
using OpConversionPattern::OpConversionPattern;
+ ConvertVectorStore(MLIRContext *context, bool useAtomicWrites)
+ : OpConversionPattern<vector::StoreOp>(context),
+ useAtomicWrites_(useAtomicWrites) {}
+
LogicalResult
matchAndRewrite(vector::StoreOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
@@ -312,8 +421,9 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
auto loc = op.getLoc();
auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
- Type oldElementType = op.getValueToStore().getType().getElementType();
- Type newElementType = convertedType.getElementType();
+ auto valueToStore = cast<VectorValue>(op.getValueToStore());
+ auto oldElementType = valueToStore.getType().getElementType();
+ auto newElementType = convertedType.getElementType();
int srcBits = oldElementType.getIntOrFloatBitWidth();
int dstBits = newElementType.getIntOrFloatBitWidth();
@@ -321,7 +431,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
return rewriter.notifyMatchFailure(
op, "only dstBits % srcBits == 0 supported");
}
- int scale = dstBits / srcBits;
+ int numSrcElemsPerDest = dstBits / srcBits;
// Adjust the number of elements to store when emulating narrow types.
// Here only the 1-D vector store is considered, and the N-D memref types
@@ -336,15 +446,15 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
// vector.store %bitcast, %alloc[%linear_index] : memref<16xi8>,
// vector<4xi8>
- auto origElements = op.getValueToStore().getType().getNumElements();
- if (origElements % scale != 0)
- return failure();
+ auto origElements = valueToStore.getType().getNumElements();
+ bool isUnalignedEmulation = origElements % numSrcElemsPerDest != 0;
auto stridedMetadata =
rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
OpFoldResult linearizedIndices;
- std::tie(std::ignore, linearizedIndices) =
+ memref::LinearizedMemRefInfo linearizedInfo;
+ std::tie(linearizedInfo, linearizedIndices) =
memref::getLinearizedMemRefOffsetAndSize(
rewriter, loc, srcBits, dstBits,
stridedMetadata.getConstifiedMixedOffset(),
@@ -352,16 +462,142 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
stridedMetadata.getConstifiedMixedStrides(),
getAsOpFoldResult(adaptor.getIndices()));
- auto numElements = origElements / scale;
- auto bitCast = rewriter.create<vector::BitCastOp>(
- loc, VectorType::get(numElements, newElementType),
- op.getValueToStore());
+ auto foldedNumFrontPadElems =
+ isUnalignedEmulation
+ ? getConstantIntValue(linearizedInfo.intraDataOffset)
+ : 0;
+
+ if (!foldedNumFrontPadElems) {
+ // Unimplemented case for dynamic front padding size != 0
+ return failure();
+ }
+
+ auto linearizedMemref = cast<MemRefValue>(adaptor.getBase());
+
+ // Shortcut: conditions when subbyte store at the front is not needed:
+ // 1. The source vector size is multiple of byte size
+ // 2. The address of the store is aligned to the emulated width boundary
+ if (!isUnalignedEmulation && *foldedNumFrontPadElems == 0) {
+ auto numElements = origElements / numSrcElemsPerDest;
+ auto bitCast = rewriter.create<vector::BitCastOp>(
+ loc, VectorType::get(numElements, newElementType),
+ op.getValueToStore());
+ rewriter.replaceOpWithNewOp<vector::StoreOp>(
+ op, bitCast.getResult(), linearizedMemref,
+ getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
+ return success();
+ }
+
+ // The index into the target memref we are storing to
+ Value currentDestIndex =
+ getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices);
+ auto subWidthStoreMaskType =
+ VectorType::get({numSrcElemsPerDest}, rewriter.getI1Type());
+ // The index into the source vector we are currently processing
+ auto currentSourceIndex = 0;
+
+ // 1. Partial width store for the first byte, when the store address is not
+ // aligned to emulated width boundary, deal with the unaligned part so that
+ // the rest elements are aligned to width boundary.
+ auto frontSubWidthStoreElem =
+ (numSrcElemsPerDest - *foldedNumFrontPadElems) % numSrcElemsPerDest;
+    if (frontSubWidthStoreElem > 0) {
+      // Mask over one container element: true for the lanes written by this
+      // store, false for the lanes whose memory contents must be preserved.
+      SmallVector<bool> frontMaskValues(numSrcElemsPerDest, false);
+      if (*foldedNumFrontPadElems + origElements < numSrcElemsPerDest) {
+        // The whole source vector fits inside the first container element:
+        // enable exactly `origElements` lanes right after the front padding.
+        std::fill_n(frontMaskValues.begin() + *foldedNumFrontPadElems,
+                    origElements, true);
+        frontSubWidthStoreElem = origElements;
+      } else {
+        // The store runs to the end of the first container element: enable
+        // its trailing `frontSubWidthStoreElem` lanes. The count has to be the
+        // number of stored lanes, not the number of padding lanes; using the
+        // padding count writes past the end of the mask when the padding is
+        // larger than the stored part and leaves lanes unset when it is
+        // smaller.
+        std::fill_n(frontMaskValues.end() - frontSubWidthStoreElem,
+                    frontSubWidthStoreElem, true);
+      }
+      auto frontMask = rewriter.create<arith::ConstantOp>(
+          loc, DenseElementsAttr::get(subWidthStoreMaskType, frontMaskValues));
+
+      // Source elements consumed by the front partial store (may exceed
+      // `origElements` when the whole vector fit in this one element).
+      currentSourceIndex = numSrcElemsPerDest - (*foldedNumFrontPadElems);
+      auto value =
+          extractSliceIntoByte(rewriter, loc, valueToStore, 0,
+                               frontSubWidthStoreElem, *foldedNumFrontPadElems);
+
+      subEmulatedWidthStore(rewriter, loc, linearizedMemref, currentDestIndex,
+                            cast<VectorValue>(value), frontMask.getResult());
+    }
+
+ if (currentSourceIndex >= origElements) {
+ rewriter.eraseOp(op);
+ return success();
+ }
- rewriter.replaceOpWithNewOp<vector::StoreOp>(
- op, bitCast.getResult(), adaptor.getBase(),
- getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
+    // Step past the container element written by the front partial store so
+    // that the remaining stores start on an emulated-width boundary. When no
+    // front partial store was emitted (zero front padding), the destination
+    // index already sits on a boundary and must not be advanced; otherwise
+    // the rest of the vector would land one container element too far.
+    if (frontSubWidthStoreElem > 0) {
+      auto constantOne = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+      currentDestIndex = rewriter.create<arith::AddIOp>(
+          loc, rewriter.getIndexType(), currentDestIndex, constantOne);
+    }
+
+ // 2. Full width store. After the previous step, the store address is
+ // aligned to the emulated width boundary.
+ int64_t fullWidthStoreSize =
+ (origElements - currentSourceIndex) / numSrcElemsPerDest;
+ int64_t numNonFullWidthElements = fullWidthStoreSize * numSrcElemsPerDest;
+ if (fullWidthStoreSize > 0) {
+ auto fullWidthStorePart = staticallyExtractSubvector(
+ rewriter, loc, valueToStore, currentSourceIndex,
+ numNonFullWidthElements);
+
+ auto originType = cast<VectorType>(fullWidthStorePart.getType());
+ auto memrefElemType = getElementTypeOrSelf(linearizedMemref.getType());
+ auto storeType = VectorType::get(
+ {originType.getNumElements() / numSrcElemsPerDest}, memrefElemType);
+ auto bitCast = rewriter.create<vector::BitCastOp>(loc, storeType,
+ fullWidthStorePart);
+ rewriter.create<vector::StoreOp>(loc, bitCast.getResult(),
+ linearizedMemref, currentDestIndex);
+
+ currentSourceIndex += numNonFullWidthElements;
+ currentDestIndex = rewriter.create<arith::AddIOp>(
+ loc, rewriter.getIndexType(), currentDestIndex,
+ rewriter.create<arith::ConstantIndexOp>(loc, fullWidthStoreSize));
+ }
+
+ // 3. Deal with trailing elements that are aligned to the emulated width,
+ // but their length is smaller than the emulated width.
+ auto remainingElements = origElements - currentSourceIndex;
+ if (remainingElements != 0) {
+ auto subWidthStorePart =
+ extractSliceIntoByte(rewriter, loc, cast<VectorValue>(valueToStore),
+ currentSourceIndex, remainingElements, 0);
+
+ // Generate back mask
+ auto maskValues = SmallVector<bool>(numSrcElemsPerDest, 0);
+ std::fill_n(maskValues.begin(), remainingElements, 1);
+ auto backMask = rewriter.create<arith::ConstantOp>(
+ loc, DenseElementsAttr::get(subWidthStoreMaskType, maskValues));
+
+ subEmulatedWidthStore(rewriter, loc, linearizedMemref, currentDestIndex,
+ cast<VectorValue>(subWidthStorePart),
+ backMask.getResult());
+ }
+
+ rewriter.eraseOp(op);
return success();
}
+
+  /// Stores a sub-byte-sized value to memory under a mask by forwarding all
+  /// arguments to `atomicStore` when `useAtomicWrites_` is set and to
+  /// `rmwStore` otherwise.
+  template <typename... Args>
+  void subEmulatedWidthStore(Args &&...args) const {
+    // The message is a separate `static_assert` argument so that it is
+    // actually reported on failure rather than folded into the condition.
+    static_assert(
+        std::is_same_v<decltype(atomicStore), decltype(rmwStore)>,
+        "`atomicStore` and `rmwStore` must have same signature, as per "
+        "the design to keep the code clean, which one to call is "
+        "determined by the `useAtomicWrites` flag.");
+    // Both callees share one function type, so a plain function pointer is
+    // enough; no type-erased `std::function` wrapper is needed.
+    auto *storeFunc = useAtomicWrites_ ? atomicStore : rmwStore;
+    storeFunc(std::forward<Args>(args)...);
+  }
+
+private:
+ const bool useAtomicWrites_;
};
//===----------------------------------------------------------------------===//
@@ -564,12 +800,11 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
auto resultVector = rewriter.create<arith::ConstantOp>(
loc, op.getType(), rewriter.getZeroAttr(op.getType()));
result = dynamicallyExtractSubVector(
- rewriter, loc, dyn_cast<TypedValue<VectorType>>(result), resultVector,
+ rewriter, loc, cast<VectorValue>(result), resultVector,
linearizedInfo.intraDataOffset, origElements);
} else if (isUnalignedEmulation) {
- result =
- staticallyExtractSubvector(rewriter, loc, op.getType(), result,
- *foldedIntraVectorOffset, origElements);
+ result = staticallyExtractSubvector(
+ rewriter, loc, result, *foldedIntraVectorOffset, origElements);
}
rewriter.replaceOp(op, result);
return success();
@@ -685,8 +920,8 @@ struct ConvertVectorMaskedLoad final
loc, newBitcastType, rewriter.getZeroAttr(newBitcastType));
if (!foldedIntraVectorOffset) {
passthru = dynamicallyInsertSubVector(
- rewriter, loc, dyn_cast<TypedValue<VectorType>>(passthru),
- emptyVector, linearizedInfo.intraDataOffset, origElements);
+ rewriter, loc, cast<VectorValue>(passthru), emptyVector,
+ linearizedInfo.intraDataOffset, origElements);
} else if (isUnalignedEmulation) {
passthru = staticallyInsertSubvector(rewriter, loc, passthru, emptyVector,
*foldedIntraVectorOffset);
@@ -713,7 +948,7 @@ struct ConvertVectorMaskedLoad final
loc, newSelectMaskType, rewriter.getZeroAttr(newSelectMaskType));
if (!foldedIntraVectorOffset) {
mask = dynamicallyInsertSubVector(
- rewriter, loc, dyn_cast<TypedValue<VectorType>>(mask), emptyMask,
+ rewriter, loc, cast<VectorValue>(mask), emptyMask,
linearizedInfo.intraDataOffset, origElements);
} else if (isUnalignedEmulation) {
mask = staticallyInsertSubvector(rewriter, loc, op.getMask(), emptyMask,
@@ -724,12 +959,11 @@ struct ConvertVectorMaskedLoad final
rewriter.create<arith::SelectOp>(loc, mask, bitCast, passthru);
if (!foldedIntraVectorOffset) {
result = dynamicallyExtractSubVector(
- rewriter, loc, dyn_cast<TypedValue<VectorType>>(result),
- op.getPassThru(), linearizedInfo.intraDataOffset, origElements);
+ rewriter, loc, cast<VectorValue>(result), op.getPassThru(),
+ linearizedInfo.intraDataOffset, origElements);
} else if (isUnalignedEmulation) {
- result =
- staticallyExtractSubvector(rewriter, loc, op.getType(), result,
- *foldedIntraVectorOffset, origElements);
+ result = staticallyExtractSubvector(
+ rewriter, loc, result, *foldedIntraVectorOffset, origElements);
}
rewriter.replaceOp(op, result);
@@ -812,9 +1046,8 @@ struct ConvertVectorTransferRead final
linearizedInfo.intraDataOffset,
origElements);
} else if (isUnalignedEmulation) {
- result =
- staticallyExtractSubvector(rewriter, loc, op.getType(), result,
- *foldedIntraVectorOffset, origElements);
+ result = staticallyExtractSubvector(
+ rewriter, loc, result, *foldedIntraVectorOffset, origElements);
}
rewriter.replaceOp(op, result);
@@ -1559,12 +1792,18 @@ struct RewriteVectorTranspose : OpRewritePattern<vector::TransposeOp> {
void vector::populateVectorNarrowTypeEmulationPatterns(
const arith::NarrowTypeEmulationConverter &typeConverter,
- RewritePatternSet &patterns) {
+ RewritePatternSet &patterns, bool useAtomicWrites) {
- // Populate `vector.*` conversion patterns.
- patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad, ConvertVectorStore,
+ // Populate the `vector.*` load, masked-store and transfer-read patterns.
+ // TODO: #119553 support atomicity
+ patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad,
ConvertVectorMaskedStore, ConvertVectorTransferRead>(
typeConverter, patterns.getContext());
+
+ // Populate the `vector.store` conversion pattern. The caller can opt out
+ // of atomic operations and get a plain read-modify-write sequence instead
+ // when it is known that there is no thread contention on the stored bytes.
+ patterns.insert<ConvertVectorStore>(patterns.getContext(), useAtomicWrites);
}
void vector::populateVectorNarrowTypeRewritePatterns(
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned-non-atomic.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned-non-atomic.mlir
new file mode 100644
index 00000000000000..9df595dae0f257
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned-non-atomic.mlir
@@ -0,0 +1,119 @@
+// RUN: mlir-opt --test-emulate-narrow-int="arith-compute-bitwidth=1 memref-load-bitwidth=8 atomic-store=false" --cse --split-input-file %s | FileCheck %s
+
+// TODO: remove memref.alloc() in the tests to reduce noise.
+// memref.alloc exists here because sub-byte vector data types such as i2
+// are currently not supported as input arguments.
+
+func.func @vector_store_i2_const_index_two_rmw(%arg0: vector<3xi2>) {
+ %0 = memref.alloc() : memref<3x3xi2>
+ %c0 = arith.constant 0 : index
+ %c2 = arith.constant 2 : index
+ vector.store %arg0, %0[%c2, %c0] :memref<3x3xi2>, vector<3xi2>
+ return
+}
+// Store to bits [12:18), i.e. bytes [1:2] of 3 total bytes; both bytes need RMW.
+
+// CHECK: func @vector_store_i2_const_index_two_rmw(
+// CHECK-SAME: %[[ARG0:.+]]: vector<3xi2>)
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+
+// Part 1 RMW sequence
+// CHECK: %[[CST:.+]] = arith.constant dense<[false, false, true, true]>
+// CHECK: %[[CST0:.+]] = arith.constant dense<0> : vector<4xi2>
+// CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [0], sizes = [2], strides = [1]} : vector<3xi2> to vector<2xi2>
+// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[EXTRACT]], %[[CST0]]
+// CHECK-SAME: {offsets = [2], strides = [1]} : vector<2xi2> into vector<4xi2>
+// CHECK: %[[LOAD:.+]] = vector.load
+// CHECK: %[[UPCAST:.+]] = vector.bitcast %[[LOAD]] : vector<1xi8> to vector<4xi2>
+// CHECK: %[[SELECT:.+]] = arith.select %[[CST]], %[[UPCAST]], %[[INSERT]]
+// CHECK: %[[DOWNCAST:.+]] = vector.bitcast %[[SELECT]]
+// CHECK: vector.store %[[DOWNCAST]], %[[ALLOC]][%[[C1]]]
+
+// Part 2 RMW sequence
+// CHECK: %[[OFFSET:.+]] = arith.addi %[[C1]], %[[C1]] : index
+// CHECK: %[[EXTRACT2:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [2], sizes = [1], strides = [1]} : vector<3xi2> to vector<1xi2>
+// CHECK: %[[INSERT2:.+]] = vector.insert_strided_slice %[[EXTRACT2]], %[[CST0]]
+// CHECK-SAME: {offsets = [0], strides = [1]} : vector<1xi2> into vector<4xi2>
+// CHECK: %[[CST1:.+]] = arith.constant dense<[true, false, false, false]> : vector<4xi1>
+// CHECK: %[[LOAD2:.+]] = vector.load
+// CHECK: %[[UPCAST2:.+]] = vector.bitcast %[[LOAD2]] : vector<1xi8> to vector<4xi2>
+// CHECK: %[[SELECT2:.+]] = arith.select %[[CST1]], %[[UPCAST2]], %[[INSERT2]]
+// CHECK: %[[DOWNCAST2:.+]] = vector.bitcast %[[SELECT2]]
+// CHECK: vector.store %[[DOWNCAST2]], %[[ALLOC]][%[[OFFSET]]]
+
+
+// -----
+
+func.func @vector_store_i2_rmw(%arg0: vector<7xi2>) {
+ %0 = memref.alloc() : memref<3x7xi2>
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ vector.store %arg0, %0[%c1, %c0] :memref<3x7xi2>, vector<7xi2>
+ return
+}
+
+// CHECK: func @vector_store_i2_rmw(
+// CHECK-SAME: %[[ARG0:.+]]:
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<6xi8>
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[CST:.+]] = arith.constant dense<[false, false, false, true]>
+// CHECK: %[[CST0:.+]] = arith.constant dense<0> : vector<4xi2>
+// CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [0], sizes = [1], strides = [1]}
+// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[EXTRACT]], %[[CST0]]
+// CHECK-SAME: {offsets = [3], strides = [1]}
+// First sub-width RMW:
+// CHECK: %[[LOAD:.+]] = vector.load %[[ALLOC]][%[[C1]]]
+// CHECK: %[[UPCAST:.+]] = vector.bitcast %[[LOAD]] : vector<1xi8> to vector<4xi2>
+// CHECK: %[[SELECT:.+]] = arith.select %[[CST]], %[[UPCAST]], %[[INSERT]]
+// CHECK: %[[DOWNCAST:.+]] = vector.bitcast %[[SELECT]]
+// CHECK: vector.store %[[DOWNCAST]], %[[ALLOC]][%[[C1]]]
+
+// Full-width store:
+// CHECK: %[[INDEX:.+]] = arith.addi %[[C1]], %[[C1]]
+// CHECK: %[[EXTRACT1:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [1], sizes = [4], strides = [1]}
+// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[EXTRACT1]]
+// CHECK: vector.store %[[BITCAST]], %[[ALLOC]][%[[INDEX]]]
+
+// Second sub-width RMW:
+// CHECK: %[[INDEX2:.+]] = arith.addi %[[INDEX]], %[[C1]]
+// CHECK: %[[EXTRACT2:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [5], sizes = [2], strides = [1]}
+// CHECK: %[[INSERT2:.+]] = vector.insert_strided_slice %[[EXTRACT2]]
+// CHECK-SAME: {offsets = [0], strides = [1]}
+// CHECK: %[[CST1:.+]] = arith.constant dense<[true, true, false, false]>
+// CHECK: %[[LOAD1:.+]] = vector.load %[[ALLOC]][%[[INDEX2]]]
+// CHECK: %[[UPCAST1:.+]] = vector.bitcast %[[LOAD1]]
+// CHECK: %[[SELECT1:.+]] = arith.select %[[CST1]], %[[UPCAST1]], %[[INSERT2]]
+// CHECK: %[[DOWNCAST1:.+]] = vector.bitcast %[[SELECT1]]
+// CHECK: vector.store %[[DOWNCAST1]], %[[ALLOC]][%[[INDEX2]]]
+
+// -----
+
+func.func @vector_store_i2_single_rmw(%arg0: vector<1xi2>) {
+ %0 = memref.alloc() : memref<4x1xi2>
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ vector.store %arg0, %0[%c1, %c0] :memref<4x1xi2>, vector<1xi2>
+ return
+}
+
+// in this test, only emit 1 rmw store
+// CHECK: func @vector_store_i2_single_rmw(
+// CHECK-SAME: %[[ARG0:.+]]: vector<1xi2>)
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<1xi8>
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[CST:.+]] = arith.constant dense<[false, true, false, false]>
+// CHECK: %[[CST0:.+]] = arith.constant dense<0> : vector<4xi2>
+// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[ARG0]], %[[CST0]]
+// CHECK-SAME: {offsets = [1], strides = [1]} : vector<1xi2> into vector<4xi2>
+// CHECK: %[[LOAD:.+]] = vector.load %[[ALLOC]][%[[C0]]] : memref<1xi8>, vector<1xi8>
+// CHECK: %[[UPCAST:.+]] = vector.bitcast %[[LOAD]] : vector<1xi8> to vector<4xi2>
+// CHECK: %[[SELECT:.+]] = arith.select %[[CST]], %[[UPCAST]], %[[INSERT]]
+// CHECK: %[[DOWNCAST:.+]] = vector.bitcast %[[SELECT]]
+// CHECK: vector.store %[[DOWNCAST]], %[[ALLOC]][%[[C0]]]
+
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
index 4332e80feed421..b01f9165d9eb74 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
@@ -356,3 +356,140 @@ func.func @vector_maskedload_i2_constant_mask_unaligned(%passthru: vector<5xi2>)
// CHECK: %[[RESULT:.+]] = vector.extract_strided_slice %[[SELECT]]
// CHECK-SAME: {offsets = [1], sizes = [5], strides = [1]} : vector<8xi2> to vector<5xi2>
// CHECK: return %[[RESULT]] : vector<5xi2>
+
+///----------------------------------------------------------------------------------------
+/// vector.store
+///----------------------------------------------------------------------------------------
+
+func.func @vector_store_i2_const_index_two_atomic(%arg0: vector<3xi2>) {
+ %0 = memref.alloc() : memref<3x3xi2>
+ %c0 = arith.constant 0 : index
+ %c2 = arith.constant 2 : index
+ vector.store %arg0, %0[%c2, %c0] :memref<3x3xi2>, vector<3xi2>
+ return
+}
+
+// In this example, emit 2 atomic RMWs.
+// Store to bits [12, 18), i.e. bytes [1, 2] of 3 bytes in total; both bytes need an RMW.
+
+// CHECK-LABEL: func @vector_store_i2_const_index_two_atomic(
+// CHECK-SAME: %[[ARG0:.+]]: vector<3xi2>)
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[CST:.+]] = arith.constant dense<[false, false, true, true]> : vector<4xi1>
+// CHECK: %[[CST_0:.+]] = arith.constant dense<0> : vector<4xi2>
+
+// Part 1 atomic RMW sequence
+// CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [0], sizes = [2], strides = [1]} : vector<3xi2> to vector<2xi2>
+// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[EXTRACT]], %[[CST_0]]
+// CHECK-SAME: {offsets = [2], strides = [1]} : vector<2xi2> into vector<4xi2>
+// CHECK: %[[ATOMIC_RMW:.+]] = memref.generic_atomic_rmw %[[ALLOC]][%[[C1]]] : memref<3xi8> {
+// CHECK: %[[ARG:.+]]: i8):
+// CHECK: %[[FROM_ELEM:.+]] = vector.from_elements %[[ARG]] : vector<1xi8>
+// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[FROM_ELEM]] : vector<1xi8> to vector<4xi2>
+// CHECK: %[[SELECT:.+]] = arith.select %[[CST]], %[[INSERT]], %[[BITCAST]] : vector<4xi1>, vector<4xi2>
+// CHECK: %[[BITCAST2:.+]] = vector.bitcast %[[SELECT]] : vector<4xi2> to vector<1xi8>
+// CHECK: %[[EXTRACT2:.+]] = vector.extract %[[BITCAST2]][0] : i8 from vector<1xi8>
+// CHECK: memref.atomic_yield %[[EXTRACT2]] : i8
+
+// Part 2 atomic RMW sequence
+// CHECK: %[[ADDR2:.+]] = arith.addi %[[C1]], %[[C1]] : index
+// CHECK: %[[EXTRACT3:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [2], sizes = [1], strides = [1]} : vector<3xi2> to vector<1xi2>
+// CHECK: %[[INSERT2:.+]] = vector.insert_strided_slice %[[EXTRACT3]], %[[CST_0]]
+// CHECK-SAME: {offsets = [0], strides = [1]} : vector<1xi2> into vector<4xi2>
+// CHECK: %[[CST1:.+]] = arith.constant dense<[true, false, false, false]> : vector<4xi1>
+// CHECK: %[[ATOMIC_RMW2:.+]] = memref.generic_atomic_rmw %[[ALLOC]][%[[ADDR2]]] : memref<3xi8> {
+// CHECK: %[[ARG2:.+]]: i8):
+// CHECK: %[[FROM_ELEM2:.+]] = vector.from_elements %[[ARG2]] : vector<1xi8>
+// CHECK: %[[BITCAST4:.+]] = vector.bitcast %[[FROM_ELEM2]] : vector<1xi8> to vector<4xi2>
+// CHECK: %[[SELECT2:.+]] = arith.select %[[CST1]], %[[INSERT2]], %[[BITCAST4]] : vector<4xi1>, vector<4xi2>
+// CHECK: %[[BITCAST5:.+]] = vector.bitcast %[[SELECT2]] : vector<4xi2> to vector<1xi8>
+// CHECK: %[[EXTRACT4:.+]] = vector.extract %[[BITCAST5]][0] : i8 from vector<1xi8>
+// CHECK: memref.atomic_yield %[[EXTRACT4]] : i8
+
+// -----
+
+func.func @vector_store_i2_atomic(%arg0: vector<7xi2>) {
+ %0 = memref.alloc() : memref<3x7xi2>
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ vector.store %arg0, %0[%c1, %c0] :memref<3x7xi2>, vector<7xi2>
+ return
+}
+
+// In this example, emit 2 atomic RMWs and 1 non-atomic store:
+// CHECK-LABEL: func @vector_store_i2_atomic(
+// CHECK-SAME: %[[ARG0:.+]]: vector<7xi2>)
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<6xi8>
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[CST:.+]] = arith.constant dense<[false, false, false, true]> : vector<4xi1>
+// CHECK: %[[CST0:.+]] = arith.constant dense<0> : vector<4xi2>
+
+// First atomic RMW:
+// CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [0], sizes = [1], strides = [1]} : vector<7xi2> to vector<1xi2>
+// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[EXTRACT]], %[[CST0]]
+// CHECK-SAME: {offsets = [3], strides = [1]} : vector<1xi2> into vector<4xi2>
+// CHECK: %[[ATOMIC_RMW:.+]] = memref.generic_atomic_rmw %[[ALLOC]][%[[C1]]] : memref<6xi8> {
+// CHECK: %[[ARG:.+]]: i8):
+// CHECK: %[[FROM_ELEM:.+]] = vector.from_elements %[[ARG]] : vector<1xi8>
+// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[FROM_ELEM]] : vector<1xi8> to vector<4xi2>
+// CHECK: %[[SELECT:.+]] = arith.select %[[CST]], %[[INSERT]], %[[BITCAST]] : vector<4xi1>, vector<4xi2>
+// CHECK: %[[BITCAST2:.+]] = vector.bitcast %[[SELECT]] : vector<4xi2> to vector<1xi8>
+// CHECK: %[[EXTRACT2:.+]] = vector.extract %[[BITCAST2]][0] : i8 from vector<1xi8>
+// CHECK: memref.atomic_yield %[[EXTRACT2]] : i8
+
+// Non-atomic store:
+// CHECK: %[[ADDR:.+]] = arith.addi %[[C1]], %[[C1]] : index
+// CHECK: %[[EXTRACT2:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [1], sizes = [4], strides = [1]} : vector<7xi2> to vector<4xi2>
+// CHECK: %[[BITCAST3:.+]] = vector.bitcast %[[EXTRACT2]] : vector<4xi2> to vector<1xi8>
+// CHECK: vector.store %[[BITCAST3]], %[[ALLOC]][%[[ADDR]]] : memref<6xi8>, vector<1xi8>
+
+// Second atomic RMW:
+// CHECK: %[[ADDR2:.+]] = arith.addi %[[ADDR]], %[[C1]] : index
+// CHECK: %[[EXTRACT3:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [5], sizes = [2], strides = [1]} : vector<7xi2> to vector<2xi2>
+// CHECK: %[[INSERT2:.+]] = vector.insert_strided_slice %[[EXTRACT3]], %[[CST0]]
+// CHECK-SAME: {offsets = [0], strides = [1]} : vector<2xi2> into vector<4xi2>
+// CHECK: %[[CST1:.+]] = arith.constant dense<[true, true, false, false]> : vector<4xi1>
+// CHECK: %[[ATOMIC_RMW2:.+]] = memref.generic_atomic_rmw %[[ALLOC]][%[[ADDR2]]] : memref<6xi8> {
+// CHECK: %[[ARG2:.+]]: i8):
+// CHECK: %[[FROM_ELEM2:.+]] = vector.from_elements %[[ARG2]] : vector<1xi8>
+// CHECK: %[[BITCAST4:.+]] = vector.bitcast %[[FROM_ELEM2]] : vector<1xi8> to vector<4xi2>
+// CHECK: %[[SELECT2:.+]] = arith.select %[[CST1]], %[[INSERT2]], %[[BITCAST4]] :
+// CHECK-SAME: vector<4xi1>, vector<4xi2>
+// CHECK: %[[BITCAST5:.+]] = vector.bitcast %[[SELECT2]] : vector<4xi2> to vector<1xi8>
+// CHECK: %[[EXTRACT4:.+]] = vector.extract %[[BITCAST5]][0] : i8 from vector<1xi8>
+// CHECK: memref.atomic_yield %[[EXTRACT4]] : i8
+
+// -----
+
+func.func @vector_store_i2_single_atomic(%arg0: vector<1xi2>) {
+ %0 = memref.alloc() : memref<4x1xi2>
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ vector.store %arg0, %0[%c1, %c0] :memref<4x1xi2>, vector<1xi2>
+ return
+}
+
+// In this example, only 1 atomic RMW is emitted.
+// CHECK-LABEL: func @vector_store_i2_single_atomic(
+// CHECK-SAME: %[[ARG0:.+]]: vector<1xi2>)
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<1xi8>
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[CST:.+]] = arith.constant dense<[false, true, false, false]> : vector<4xi1>
+// CHECK: %[[CST0:.+]] = arith.constant dense<0> : vector<4xi2>
+// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[ARG0]], %[[CST0]]
+// CHECK-SAME: {offsets = [1], strides = [1]} : vector<1xi2> into vector<4xi2>
+
+// CHECK: %[[ATOMIC_RMW:.+]] = memref.generic_atomic_rmw %[[ALLOC]][%[[C0]]] : memref<1xi8> {
+// CHECK: %[[ARG:.+]]: i8):
+// CHECK: %[[FROM_ELEM:.+]] = vector.from_elements %[[ARG]] : vector<1xi8>
+// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[FROM_ELEM]] : vector<1xi8> to vector<4xi2>
+// CHECK: %[[SELECT:.+]] = arith.select %[[CST]], %[[INSERT]], %[[BITCAST]] : vector<4xi1>, vector<4xi2>
+// CHECK: %[[BITCAST2:.+]] = vector.bitcast %[[SELECT]] : vector<4xi2> to vector<1xi8>
+// CHECK: %[[EXTRACT2:.+]] = vector.extract %[[BITCAST2]][0] : i8 from vector<1xi8>
+// CHECK: memref.atomic_yield %[[EXTRACT2]] : i8
diff --git a/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp b/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp
index 7401e470ed4f2c..9a3fac623fbd7d 100644
--- a/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp
+++ b/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp
@@ -99,7 +99,8 @@ struct TestEmulateNarrowTypePass
arith::populateArithNarrowTypeEmulationPatterns(typeConverter, patterns);
memref::populateMemRefNarrowTypeEmulationPatterns(typeConverter, patterns);
- vector::populateVectorNarrowTypeEmulationPatterns(typeConverter, patterns);
+ vector::populateVectorNarrowTypeEmulationPatterns(typeConverter, patterns,
+ atomicStore);
if (failed(applyPartialConversion(op, target, std::move(patterns))))
signalPassFailure();
@@ -118,6 +119,11 @@ struct TestEmulateNarrowTypePass
*this, "skip-memref-type-conversion",
llvm::cl::desc("disable memref type conversion (to test failures)"),
llvm::cl::init(false)};
+
+ Option<bool> atomicStore{
+ *this, "atomic-store",
+ llvm::cl::desc("use atomic store instead of load-modify-write"),
+ llvm::cl::init(true)};
};
} // namespace
More information about the Mlir-commits
mailing list