[Mlir-commits] [mlir] [MLIR] Implement emulation of static indexing subbyte type vector stores (PR #115922)
llvmlistbot at llvm.org
Wed Dec 11 05:05:45 PST 2024
https://github.com/lialan updated https://github.com/llvm/llvm-project/pull/115922
>From 72bdea0cc8c0d8975c89588e404245a82a9e35c0 Mon Sep 17 00:00:00 2001
From: Ubuntu <450283+lialan at users.noreply.github.com>
Date: Fri, 25 Oct 2024 15:19:42 +0000
Subject: [PATCH 01/12] Implement vector stores
---
.../Vector/Transforms/VectorRewritePatterns.h | 5 +-
.../Transforms/VectorEmulateNarrowType.cpp | 265 +++++++++++++++---
...tor-emulate-narrow-type-unaligned-rmw.mlir | 104 +++++++
.../vector-emulate-narrow-type-unaligned.mlir | 90 ++++++
.../Dialect/MemRef/TestEmulateNarrowType.cpp | 8 +-
5 files changed, 436 insertions(+), 36 deletions(-)
create mode 100644 mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned-rmw.mlir
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index a59f06f3c1ef1b..64bb3a2204cfdc 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -364,10 +364,11 @@ void populateVectorMaskMaterializationPatterns(RewritePatternSet &patterns,
PatternBenefit benefit = 1);
/// Appends patterns for emulating vector operations over narrow types with ops
-/// over wider types.
+/// over wider types. `useAtomicWrites` indicates whether to use atomic
+/// operations in the places where thread contention is possible.
void populateVectorNarrowTypeEmulationPatterns(
const arith::NarrowTypeEmulationConverter &typeConverter,
- RewritePatternSet &patterns);
+ RewritePatternSet &patterns, bool useAtomicWrites = true);
/// Rewrite a vector `bitcast(trunci)` to use a more efficient sequence of
/// vector operations comprising `shuffle` and `bitwise` ops.
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 87c30a733c363e..278b42a5b7c104 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -33,6 +33,7 @@
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
#include <cstdint>
@@ -211,13 +212,10 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
/// Extracts 1-D subvector from a 1-D vector. It is a wrapper function for
/// emitting `vector.extract_strided_slice`.
static Value staticallyExtractSubvector(OpBuilder &rewriter, Location loc,
- VectorType extractType, Value source,
- int64_t frontOffset,
+ Value source, int64_t frontOffset,
int64_t subvecSize) {
auto vectorType = cast<VectorType>(source.getType());
- assert((vectorType.getRank() == 1 && extractType.getRank() == 1) &&
- "expected 1-D source and destination types");
- (void)vectorType;
+ assert(vectorType.getRank() == 1 && "expected 1-D source types");
assert(frontOffset + subvecSize <= vectorType.getNumElements() &&
"subvector out of bounds");
@@ -228,9 +226,12 @@ static Value staticallyExtractSubvector(OpBuilder &rewriter, Location loc,
auto offsets = rewriter.getI64ArrayAttr({frontOffset});
auto sizes = rewriter.getI64ArrayAttr({subvecSize});
auto strides = rewriter.getI64ArrayAttr({1});
+
+ auto resultVectorType =
+ VectorType::get({subvecSize}, vectorType.getElementType());
return rewriter
- .create<vector::ExtractStridedSliceOp>(loc, extractType, source, offsets,
- sizes, strides)
+ .create<vector::ExtractStridedSliceOp>(loc, resultVectorType, source,
+ offsets, sizes, strides)
->getResult(0);
}
@@ -309,6 +310,76 @@ emulatedVectorLoad(OpBuilder &rewriter, Location loc, Value base,
newLoad);
}
+/// Atomically store a subbyte-sized value to memory, with a mask.
+static void atomicStore(OpBuilder &builder, Location loc,
+ TypedValue<MemRefType> emulatedMemref,
+ Value linearizedIndex, TypedValue<VectorType> value,
+ Value mask, int64_t) {
+ auto atomicOp = builder.create<memref::GenericAtomicRMWOp>(
+ loc, emulatedMemref, ValueRange{linearizedIndex});
+ Value origValue = atomicOp.getCurrentValue();
+
+ OpBuilder::InsertionGuard guard(builder);
+ builder.setInsertionPointToStart(atomicOp.getBody());
+
+ // i8 -> <1xi8> -> <numSrcElemsPerDest x i.>
+ auto oneVectorType = VectorType::get({1}, origValue.getType());
+ auto fromElem = builder.create<vector::FromElementsOp>(loc, oneVectorType,
+ ValueRange{origValue});
+ auto vectorBitCast =
+ builder.create<vector::BitCastOp>(loc, value.getType(), fromElem);
+
+ auto select =
+ builder.create<arith::SelectOp>(loc, mask, value, vectorBitCast);
+ auto bitcast2 = builder.create<vector::BitCastOp>(loc, oneVectorType, select);
+ auto extract = builder.create<vector::ExtractOp>(loc, bitcast2, 0);
+ builder.create<memref::AtomicYieldOp>(loc, extract.getResult());
+}
+
+/// Generate a non-atomic read-modify-write sequence for subbyte storing.
+static void rmwStore(OpBuilder &rewriter, Location loc,
+ TypedValue<MemRefType> emulatedMemref,
+ Value linearizedIndex, TypedValue<VectorType> value,
+ Value mask, int64_t numSrcElemsPerDest) {
+ auto emulatedIOType =
+ VectorType::get({1}, emulatedMemref.getType().getElementType());
+ auto elemLoad = rewriter.create<vector::LoadOp>(
+ loc, emulatedIOType, emulatedMemref, ValueRange{linearizedIndex});
+ auto fromBitcast = rewriter.create<vector::BitCastOp>(
+ loc,
+ VectorType::get({numSrcElemsPerDest}, value.getType().getElementType()),
+ elemLoad);
+ auto select = rewriter.create<arith::SelectOp>(loc, mask, fromBitcast, value);
+ auto toBitcast =
+ rewriter.create<vector::BitCastOp>(loc, emulatedIOType, select);
+ rewriter.create<vector::StoreOp>(loc, toBitcast, emulatedMemref,
+ linearizedIndex);
+}
+
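Both helpers above implement the same masked byte merge; `atomicStore` wraps it in `memref.generic_atomic_rmw`, while `rmwStore` uses a plain load/select/store. As a rough scalar model (not part of the patch; `maskBits` is a hypothetical byte-wide expansion of the per-element i1 mask, with all bits of a selected sub-byte lane set):

  #include <cstdint>

  // Lanes selected by the mask take their bits from the value being stored;
  // the remaining lanes keep the byte's previous contents, matching the
  // arith.select in the emitted IR.
  static uint8_t mergeSubByteStore(uint8_t oldByte, uint8_t newByte,
                                   uint8_t maskBits) {
    return static_cast<uint8_t>((oldByte & ~maskBits) | (newByte & maskBits));
  }
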
+static_assert(std::is_same_v<decltype(atomicStore), decltype(rmwStore)> &&
+ "`atomicStore` and `rmwStore` must have same signature, as per "
+ "the design to keep the code clean, which one to call is "
+ "determined by the `useAtomicWrites` flag.");
+
+// Extract a slice of a vector, and insert it into a byte vector.
+static Value extractSliceIntoByte(ConversionPatternRewriter &rewriter,
+ Location loc, TypedValue<VectorType> vector,
+ int64_t sliceOffset, int64_t sliceNumElements,
+ int64_t byteOffset) {
+ auto vectorElementType = vector.getType().getElementType();
+ assert(8 % vectorElementType.getIntOrFloatBitWidth() == 0 &&
+ "vector element must be a valid sub-byte type");
+ auto scale = 8 / vectorElementType.getIntOrFloatBitWidth();
+ auto emptyByteVector = rewriter.create<arith::ConstantOp>(
+ loc, VectorType::get({scale}, vectorElementType),
+ rewriter.getZeroAttr(VectorType::get({scale}, vectorElementType)));
+ auto extracted = staticallyExtractSubvector(rewriter, loc, vector,
+ sliceOffset, sliceNumElements);
+ auto inserted = staticallyInsertSubvector(rewriter, loc, extracted,
+ emptyByteVector, byteOffset);
+ return inserted;
+}
+
namespace {
//===----------------------------------------------------------------------===//
@@ -318,6 +389,10 @@ namespace {
struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
using OpConversionPattern::OpConversionPattern;
+ ConvertVectorStore(MLIRContext *context, bool useAtomicWrites)
+ : OpConversionPattern<vector::StoreOp>(context),
+ useAtomicWrites_(useAtomicWrites) {}
+
LogicalResult
matchAndRewrite(vector::StoreOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
@@ -329,8 +404,9 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
auto loc = op.getLoc();
auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
- Type oldElementType = op.getValueToStore().getType().getElementType();
- Type newElementType = convertedType.getElementType();
+ auto valueToStore = cast<TypedValue<VectorType>>(op.getValueToStore());
+ auto oldElementType = valueToStore.getType().getElementType();
+ auto newElementType = convertedType.getElementType();
int srcBits = oldElementType.getIntOrFloatBitWidth();
int dstBits = newElementType.getIntOrFloatBitWidth();
@@ -338,7 +414,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
return rewriter.notifyMatchFailure(
op, "only dstBits % srcBits == 0 supported");
}
- int scale = dstBits / srcBits;
+ int numSrcElemsPerDest = dstBits / srcBits;
// Adjust the number of elements to store when emulating narrow types.
// Here only the 1-D vector store is considered, and the N-D memref types
@@ -353,15 +429,15 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
// vector.store %bitcast, %alloc[%linear_index] : memref<16xi8>,
// vector<4xi8>
- auto origElements = op.getValueToStore().getType().getNumElements();
- if (origElements % scale != 0)
- return failure();
+ auto origElements = valueToStore.getType().getNumElements();
+ bool isUnalignedEmulation = origElements % numSrcElemsPerDest != 0;
auto stridedMetadata =
rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
OpFoldResult linearizedIndices;
- std::tie(std::ignore, linearizedIndices) =
+ memref::LinearizedMemRefInfo linearizedInfo;
+ std::tie(linearizedInfo, linearizedIndices) =
memref::getLinearizedMemRefOffsetAndSize(
rewriter, loc, srcBits, dstBits,
stridedMetadata.getConstifiedMixedOffset(),
@@ -369,16 +445,137 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
stridedMetadata.getConstifiedMixedStrides(),
getAsOpFoldResult(adaptor.getIndices()));
- auto numElements = origElements / scale;
- auto bitCast = rewriter.create<vector::BitCastOp>(
- loc, VectorType::get(numElements, newElementType),
- op.getValueToStore());
+ auto foldedNumFrontPadElems =
+ isUnalignedEmulation
+ ? getConstantIntValue(linearizedInfo.intraDataOffset)
+ : 0;
- rewriter.replaceOpWithNewOp<vector::StoreOp>(
- op, bitCast.getResult(), adaptor.getBase(),
- getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
+ if (!foldedNumFrontPadElems) {
+ // Unimplemented case for dynamic front padding size != 0
+ return failure();
+ }
+
+ auto emulatedMemref = cast<TypedValue<MemRefType>>(adaptor.getBase());
+
+ // Shortcut: conditions when subbyte store at the front is not needed:
+    // 1. The source vector size is a multiple of the byte size
+ // 2. The address of the store is aligned to the emulated width boundary
+ if (!isUnalignedEmulation && *foldedNumFrontPadElems == 0) {
+ auto numElements = origElements / numSrcElemsPerDest;
+ auto bitCast = rewriter.create<vector::BitCastOp>(
+ loc, VectorType::get(numElements, newElementType),
+ op.getValueToStore());
+ rewriter.replaceOpWithNewOp<vector::StoreOp>(
+ op, bitCast.getResult(), emulatedMemref,
+ getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
+ return success();
+ }
+
+ // The index into the target memref we are storing to
+ Value currentDestIndex =
+ getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices);
+ auto constantOne = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+ auto subWidthStoreMaskType =
+ VectorType::get({numSrcElemsPerDest}, rewriter.getI1Type());
+ // The index into the source vector we are currently processing
+ auto currentSourceIndex = 0;
+
+    // 1. Partial-width store for the first byte: when the store address is not
+    // aligned to the emulated width boundary, deal with the unaligned part so that
+    // the remaining elements are aligned to the width boundary.
+ auto frontSubWidthStoreElem =
+ (numSrcElemsPerDest - *foldedNumFrontPadElems) % numSrcElemsPerDest;
+ if (frontSubWidthStoreElem != 0) {
+ SmallVector<bool> frontMaskValues(numSrcElemsPerDest, false);
+ if (*foldedNumFrontPadElems + origElements < numSrcElemsPerDest) {
+ std::fill_n(frontMaskValues.begin() + *foldedNumFrontPadElems,
+ origElements, true);
+ frontSubWidthStoreElem = origElements;
+ } else {
+ std::fill_n(frontMaskValues.end() - frontSubWidthStoreElem,
+ *foldedNumFrontPadElems, true);
+ }
+ auto frontMask = rewriter.create<arith::ConstantOp>(
+ loc, DenseElementsAttr::get(subWidthStoreMaskType, frontMaskValues));
+
+ currentSourceIndex = numSrcElemsPerDest - (*foldedNumFrontPadElems);
+ auto value =
+ extractSliceIntoByte(rewriter, loc, valueToStore, 0,
+ frontSubWidthStoreElem, *foldedNumFrontPadElems);
+
+ subEmulatedWidthStore(rewriter, loc, emulatedMemref, currentDestIndex,
+ cast<TypedValue<VectorType>>(value),
+ frontMask.getResult(), numSrcElemsPerDest);
+
+ currentDestIndex = rewriter.create<arith::AddIOp>(
+ loc, rewriter.getIndexType(), currentDestIndex, constantOne);
+ }
+
+ if (currentSourceIndex >= origElements) {
+ rewriter.eraseOp(op);
+ return success();
+ }
+
+ // 2. Full width store. After the previous step, the store address is
+ // aligned to the emulated width boundary.
+ int64_t fullWidthStoreSize =
+ (origElements - currentSourceIndex) / numSrcElemsPerDest;
+ int64_t numNonFullWidthElements = fullWidthStoreSize * numSrcElemsPerDest;
+ if (fullWidthStoreSize != 0) {
+ auto fullWidthStorePart = staticallyExtractSubvector(
+ rewriter, loc, valueToStore, currentSourceIndex,
+ numNonFullWidthElements);
+
+ auto originType = dyn_cast<VectorType>(fullWidthStorePart.getType());
+ auto memrefElemType =
+ dyn_cast<MemRefType>(emulatedMemref.getType()).getElementType();
+ auto storeType = VectorType::get(
+ {originType.getNumElements() / numSrcElemsPerDest}, memrefElemType);
+ auto bitCast = rewriter.create<vector::BitCastOp>(loc, storeType,
+ fullWidthStorePart);
+ rewriter.create<vector::StoreOp>(loc, bitCast.getResult(), emulatedMemref,
+ currentDestIndex);
+
+ currentSourceIndex += numNonFullWidthElements;
+ currentDestIndex = rewriter.create<arith::AddIOp>(
+ loc, rewriter.getIndexType(), currentDestIndex,
+ rewriter.create<arith::ConstantIndexOp>(loc, fullWidthStoreSize));
+ }
+
+ // 3. Deal with trailing elements that are aligned to the emulated width,
+ // but their length is smaller than the emulated width.
+ auto remainingElements = origElements - currentSourceIndex;
+ if (remainingElements != 0) {
+ auto subWidthStorePart = extractSliceIntoByte(
+ rewriter, loc, cast<TypedValue<VectorType>>(valueToStore),
+ currentSourceIndex, remainingElements, 0);
+
+ // Generate back mask
+ auto maskValues = SmallVector<bool>(numSrcElemsPerDest, 0);
+ std::fill_n(maskValues.begin(), remainingElements, 1);
+ auto backMask = rewriter.create<arith::ConstantOp>(
+ loc, DenseElementsAttr::get(subWidthStoreMaskType, maskValues));
+
+ subEmulatedWidthStore(rewriter, loc, emulatedMemref, currentDestIndex,
+ cast<TypedValue<VectorType>>(subWidthStorePart),
+ backMask.getResult(), numSrcElemsPerDest);
+ }
+
+ rewriter.eraseOp(op);
return success();
}
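As a worked example of how the rewrite above decomposes a store, consider the `vector<7xi2>` case exercised by the tests added below (i2 into i8, so 4 source elements per destination byte, with the linearized address starting 3 elements into a byte):

  numSrcElemsPerDest     = 8 / 2 = 4
  foldedNumFrontPadElems = 3
  frontSubWidthStoreElem = (4 - 3) % 4 = 1   // masked sub-byte store of element 0,
                                             // front mask [0, 0, 0, 1]
  fullWidthStoreSize     = (7 - 1) / 4 = 1   // one full-byte store of elements 1..4
  remainingElements      = 7 - 1 - 4 = 2     // masked sub-byte store of elements 5..6,
                                             // back mask [1, 1, 0, 0]

These counts and masks match the CHECK lines in the vector_store_i2_atomic tests further down.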
+
+ /// Store a subbyte-sized value to memory, with a mask. Depending on the
+ /// configuration, it could be an atomic store or an RMW sequence.
+ template <typename... Args>
+ void subEmulatedWidthStore(Args &&...args) const {
+ std::function<decltype(atomicStore)> storeFunc =
+ useAtomicWrites_ ? atomicStore : rmwStore;
+ storeFunc(std::forward<Args>(args)...);
+ }
+
+private:
+ const bool useAtomicWrites_;
};
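The `subEmulatedWidthStore` helper above picks between the two free functions at runtime via `std::function`. A stripped-down illustration of that dispatch idiom (hypothetical names, not code from the patch):

  #include <functional>
  #include <utility>

  static void storeAtomically(int, int) { /* atomic variant */ }
  static void storeWithRmw(int, int) { /* plain RMW variant */ }

  template <typename... Args>
  static void dispatchStore(bool useAtomicWrites, Args &&...args) {
    // Both callees share one signature, so either can be bound to the same
    // std::function object and invoked with the forwarded arguments.
    std::function<decltype(storeAtomically)> storeFunc =
        useAtomicWrites ? storeAtomically : storeWithRmw;
    storeFunc(std::forward<Args>(args)...);
  }
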
//===----------------------------------------------------------------------===//
@@ -584,9 +781,8 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
rewriter, loc, dyn_cast<TypedValue<VectorType>>(result), resultVector,
linearizedInfo.intraDataOffset, origElements);
} else if (isUnalignedEmulation) {
- result =
- staticallyExtractSubvector(rewriter, loc, op.getType(), result,
- *foldedIntraVectorOffset, origElements);
+ result = staticallyExtractSubvector(
+ rewriter, loc, result, *foldedIntraVectorOffset, origElements);
}
rewriter.replaceOp(op, result);
return success();
@@ -745,9 +941,8 @@ struct ConvertVectorMaskedLoad final
rewriter, loc, dyn_cast<TypedValue<VectorType>>(result),
op.getPassThru(), linearizedInfo.intraDataOffset, origElements);
} else if (isUnalignedEmulation) {
- result =
- staticallyExtractSubvector(rewriter, loc, op.getType(), result,
- *foldedIntraVectorOffset, origElements);
+ result = staticallyExtractSubvector(
+ rewriter, loc, result, *foldedIntraVectorOffset, origElements);
}
rewriter.replaceOp(op, result);
@@ -830,9 +1025,8 @@ struct ConvertVectorTransferRead final
linearizedInfo.intraDataOffset,
origElements);
} else if (isUnalignedEmulation) {
- result =
- staticallyExtractSubvector(rewriter, loc, op.getType(), result,
- *foldedIntraVectorOffset, origElements);
+ result = staticallyExtractSubvector(
+ rewriter, loc, result, *foldedIntraVectorOffset, origElements);
}
rewriter.replaceOp(op, result);
@@ -1577,12 +1771,17 @@ struct RewriteVectorTranspose : OpRewritePattern<vector::TransposeOp> {
void vector::populateVectorNarrowTypeEmulationPatterns(
const arith::NarrowTypeEmulationConverter &typeConverter,
- RewritePatternSet &patterns) {
+ RewritePatternSet &patterns, bool useAtomicWrites) {
- // Populate `vector.*` conversion patterns.
- patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad, ConvertVectorStore,
+ // Populate `vector.*` load conversion patterns.
+ patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad,
ConvertVectorMaskedStore, ConvertVectorTransferRead>(
typeConverter, patterns.getContext());
+
+  // Populate `vector.*` store conversion patterns. The caller can choose
+  // to avoid emitting atomic operations and reduce them to load-modify-write
+  // sequences if it is known that there is no thread contention.
+ patterns.insert<ConvertVectorStore>(patterns.getContext(), useAtomicWrites);
}
void vector::populateVectorNarrowTypeRewritePatterns(
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned-rmw.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned-rmw.mlir
new file mode 100644
index 00000000000000..fa4d9cb5e4d4c7
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned-rmw.mlir
@@ -0,0 +1,104 @@
+// RUN: mlir-opt --test-emulate-narrow-int="arith-compute-bitwidth=1 memref-load-bitwidth=8 atomic-store=false" --cse --split-input-file %s | FileCheck %s
+
+// TODO: remove memref.alloc() in the tests to eliminate noise.
+// memref.alloc exists here because sub-byte vector data types such as i2
+// are currently not supported as input arguments.
+
+func.func @vector_store_i2_const_rmw(%arg0: vector<3xi2>) {
+ %0 = memref.alloc() : memref<3x3xi2>
+ %c0 = arith.constant 0 : index
+ %c2 = arith.constant 2 : index
+ vector.store %arg0, %0[%c2, %c0] :memref<3x3xi2>, vector<3xi2>
+ return
+}
+// CHECK: func @vector_store_i2_const_rmw(
+// CHECK-SAME: %[[ARG0:.+]]: vector<3xi2>)
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[CST:.+]] = arith.constant dense<[false, false, true, true]>
+// CHECK: %[[CST0:.+]] = arith.constant dense<0> : vector<4xi2>
+// CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [0], sizes = [2], strides = [1]} : vector<3xi2> to vector<2xi2>
+// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[EXTRACT]], %[[CST0]]
+// CHECK-SAME: {offsets = [2], strides = [1]} : vector<2xi2> into vector<4xi2>
+// CHECK: %[[LOAD:.+]] = vector.load
+
+// The actual RMW sequence:
+// CHECK: %[[UPCAST:.+]] = vector.bitcast %[[LOAD]] : vector<1xi8> to vector<4xi2>
+// CHECK: %[[SELECT:.+]] = arith.select %[[CST]], %[[UPCAST]], %[[INSERT]]
+// CHECK: %[[DOWNCAST:.+]] = vector.bitcast %[[SELECT]]
+// CHECK: vector.store %[[DOWNCAST]], %[[ALLOC]][%[[C1]]]
+
+// -----
+
+func.func @vector_store_i2_atomic(%arg0: vector<7xi2>) {
+ %0 = memref.alloc() : memref<3x7xi2>
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ vector.store %arg0, %0[%c1, %c0] :memref<3x7xi2>, vector<7xi2>
+ return
+}
+
+// CHECK: func @vector_store_i2_atomic(
+// CHECK-SAME: %[[ARG0:.+]]:
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<6xi8>
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[CST:.+]] = arith.constant dense<[false, false, false, true]>
+// CHECK: %[[CST0:.+]] = arith.constant dense<0> : vector<4xi2>
+// CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [0], sizes = [1], strides = [1]}
+// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[EXTRACT]], %[[CST0]]
+// CHECK-SAME: {offsets = [3], strides = [1]}
+// First sub-width RMW:
+// CHECK: %[[LOAD:.+]] = vector.load %[[ALLOC]][%[[C1]]]
+// CHECK: %[[UPCAST:.+]] = vector.bitcast %[[LOAD]] : vector<1xi8> to vector<4xi2>
+// CHECK: %[[SELECT:.+]] = arith.select %[[CST]], %[[UPCAST]], %[[INSERT]]
+// CHECK: %[[DOWNCAST:.+]] = vector.bitcast %[[SELECT]]
+// CHECK: vector.store %[[DOWNCAST]], %[[ALLOC]][%[[C1]]]
+
+// Full-width store:
+// CHECK: %[[INDEX:.+]] = arith.addi %[[C1]], %[[C1]]
+// CHECK: %[[EXTRACT1:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [1], sizes = [4], strides = [1]}
+// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[EXTRACT1]]
+// CHECK: vector.store %[[BITCAST]], %[[ALLOC]][%[[INDEX]]]
+
+// Second sub-width RMW:
+// CHECK: %[[INDEX2:.+]] = arith.addi %[[INDEX]], %[[C1]]
+// CHECK: %[[EXTRACT2:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [5], sizes = [2], strides = [1]}
+// CHECK: %[[INSERT2:.+]] = vector.insert_strided_slice %[[EXTRACT2]]
+// CHECK-SAME: {offsets = [0], strides = [1]}
+// CHECK: %[[CST1:.+]] = arith.constant dense<[true, true, false, false]>
+// CHECK: %[[LOAD1:.+]] = vector.load %[[ALLOC]][%[[INDEX2]]]
+// CHECK: %[[UPCAST1:.+]] = vector.bitcast %[[LOAD1]]
+// CHECK: %[[SELECT1:.+]] = arith.select %[[CST1]], %[[UPCAST1]], %[[INSERT2]]
+// CHECK: %[[DOWNCAST1:.+]] = vector.bitcast %[[SELECT1]]
+// CHECK: vector.store %[[DOWNCAST1]], %[[ALLOC]][%[[INDEX2]]]
+
+// -----
+
+func.func @vector_store_i2_single_rmw(%arg0: vector<1xi2>) {
+ %0 = memref.alloc() : memref<4x1xi2>
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ vector.store %arg0, %0[%c1, %c0] :memref<4x1xi2>, vector<1xi2>
+ return
+}
+
+// In this test, only one RMW store is emitted.
+// CHECK: func @vector_store_i2_single_rmw(
+// CHECK-SAME: %[[ARG0:.+]]: vector<1xi2>)
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<1xi8>
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[CST:.+]] = arith.constant dense<[false, true, false, false]>
+// CHECK: %[[CST0:.+]] = arith.constant dense<0> : vector<4xi2>
+// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[ARG0]], %[[CST0]]
+// CHECK-SAME: {offsets = [1], strides = [1]} : vector<1xi2> into vector<4xi2>
+// CHECK: %[[LOAD:.+]] = vector.load %[[ALLOC]][%[[C0]]] : memref<1xi8>, vector<1xi8>
+// CHECK: %[[UPCAST:.+]] = vector.bitcast %[[LOAD]] : vector<1xi8> to vector<4xi2>
+// CHECK: %[[SELECT:.+]] = arith.select %[[CST]], %[[UPCAST]], %[[INSERT]]
+// CHECK: %[[DOWNCAST:.+]] = vector.bitcast %[[SELECT]]
+// CHECK: vector.store %[[DOWNCAST]], %[[ALLOC]][%[[C0]]]
+
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
index 721c8a8d5d2034..fd526ada6cb7b2 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
@@ -336,3 +336,93 @@ func.func @vector_maskedload_i4_constant_mask_unaligned(%passthru: vector<5xi2>)
// CHECK: %[[RESULT:.+]] = vector.extract_strided_slice %[[SELECT]]
// CHECK-SAME: {offsets = [1], sizes = [5], strides = [1]} : vector<8xi2> to vector<5xi2>
// CHECK: return %[[RESULT]] : vector<5xi2>
+
+///----------------------------------------------------------------------------------------
+/// vector.store
+///----------------------------------------------------------------------------------------
+
+// -----
+
+func.func @vector_store_i2_atomic(%arg0: vector<7xi2>) {
+ %0 = memref.alloc() : memref<3x7xi2>
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ vector.store %arg0, %0[%c1, %c0] :memref<3x7xi2>, vector<7xi2>
+ return
+}
+
+// In this example, emit 2 atomic RMWs and 1 non-atomic store:
+// CHECK-LABEL: func @vector_store_i2_atomic(
+// CHECK-SAME: %[[ARG0:.+]]: vector<7xi2>)
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<6xi8>
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[CST:.+]] = arith.constant dense<[false, false, false, true]> : vector<4xi1>
+// CHECK: %[[CST0:.+]] = arith.constant dense<0> : vector<4xi2>
+
+// First atomic RMW:
+// CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [0], sizes = [1], strides = [1]} : vector<7xi2> to vector<1xi2>
+// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[EXTRACT]], %[[CST0]]
+// CHECK-SAME: {offsets = [3], strides = [1]} : vector<1xi2> into vector<4xi2>
+// CHECK: %[[ATOMIC_RMW:.+]] = memref.generic_atomic_rmw %[[ALLOC]][%[[C1]]] : memref<6xi8> {
+// CHECK: %[[ARG:.+]]: i8):
+// CHECK: %[[FROM_ELEM:.+]] = vector.from_elements %[[ARG]] : vector<1xi8>
+// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[FROM_ELEM]] : vector<1xi8> to vector<4xi2>
+// CHECK: %[[SELECT:.+]] = arith.select %[[CST]], %[[INSERT]], %[[BITCAST]] : vector<4xi1>, vector<4xi2>
+// CHECK: %[[BITCAST2:.+]] = vector.bitcast %[[SELECT]] : vector<4xi2> to vector<1xi8>
+// CHECK: %[[EXTRACT2:.+]] = vector.extract %[[BITCAST2]][0] : i8 from vector<1xi8>
+// CHECK: memref.atomic_yield %[[EXTRACT2]] : i8
+
+// Non-atomic store:
+// CHECK: %[[ADDR:.+]] = arith.addi %[[C1]], %[[C1]] : index
+// CHECK: %[[EXTRACT2:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [1], sizes = [4], strides = [1]} : vector<7xi2> to vector<4xi2>
+// CHECK: %[[BITCAST3:.+]] = vector.bitcast %[[EXTRACT2]] : vector<4xi2> to vector<1xi8>
+// CHECK: vector.store %[[BITCAST3]], %[[ALLOC]][%[[ADDR]]] : memref<6xi8>, vector<1xi8>
+
+// Second atomic RMW:
+// CHECK: %[[ADDR2:.+]] = arith.addi %[[ADDR]], %[[C1]] : index
+// CHECK: %[[EXTRACT3:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [5], sizes = [2], strides = [1]} : vector<7xi2> to vector<2xi2>
+// CHECK: %[[INSERT2:.+]] = vector.insert_strided_slice %[[EXTRACT3]], %[[CST0]]
+// CHECK-SAME: {offsets = [0], strides = [1]} : vector<2xi2> into vector<4xi2>
+// CHECK: %[[CST1:.+]] = arith.constant dense<[true, true, false, false]> : vector<4xi1>
+// CHECK: %[[ATOMIC_RMW2:.+]] = memref.generic_atomic_rmw %[[ALLOC]][%[[ADDR2]]] : memref<6xi8> {
+// CHECK: %[[ARG2:.+]]: i8):
+// CHECK: %[[FROM_ELEM2:.+]] = vector.from_elements %[[ARG2]] : vector<1xi8>
+// CHECK: %[[BITCAST4:.+]] = vector.bitcast %[[FROM_ELEM2]] : vector<1xi8> to vector<4xi2>
+// CHECK: %[[SELECT2:.+]] = arith.select %[[CST1]], %[[INSERT2]], %[[BITCAST4]] :
+// CHECK-SAME: vector<4xi1>, vector<4xi2>
+// CHECK: %[[BITCAST5:.+]] = vector.bitcast %[[SELECT2]] : vector<4xi2> to vector<1xi8>
+// CHECK: %[[EXTRACT4:.+]] = vector.extract %[[BITCAST5]][0] : i8 from vector<1xi8>
+// CHECK: memref.atomic_yield %[[EXTRACT4]] : i8
+
+// -----
+
+func.func @vector_store_i2_single_atomic(%arg0: vector<1xi2>) {
+ %0 = memref.alloc() : memref<4x1xi2>
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ vector.store %arg0, %0[%c1, %c0] :memref<4x1xi2>, vector<1xi2>
+ return
+}
+
+// In this example, only emit 1 atomic store
+// CHECK-LABEL: func @vector_store_i2_single_atomic(
+// CHECK-SAME: %[[ARG0:.+]]: vector<1xi2>)
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<1xi8>
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[CST:.+]] = arith.constant dense<[false, true, false, false]> : vector<4xi1>
+// CHECK: %[[CST0:.+]] = arith.constant dense<0> : vector<4xi2>
+// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[ARG0]], %[[CST0]]
+// CHECK-SAME: {offsets = [1], strides = [1]} : vector<1xi2> into vector<4xi2>
+
+// CHECK: %[[ATOMIC_RMW:.+]] = memref.generic_atomic_rmw %[[ALLOC]][%[[C0]]] : memref<1xi8> {
+// CHECK: %[[ARG:.+]]: i8):
+// CHECK: %[[FROM_ELEM:.+]] = vector.from_elements %[[ARG]] : vector<1xi8>
+// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[FROM_ELEM]] : vector<1xi8> to vector<4xi2>
+// CHECK: %[[SELECT:.+]] = arith.select %[[CST]], %[[INSERT]], %[[BITCAST]] : vector<4xi1>, vector<4xi2>
+// CHECK: %[[BITCAST2:.+]] = vector.bitcast %[[SELECT]] : vector<4xi2> to vector<1xi8>
+// CHECK: %[[EXTRACT2:.+]] = vector.extract %[[BITCAST2]][0] : i8 from vector<1xi8>
+// CHECK: memref.atomic_yield %[[EXTRACT2]] : i8
diff --git a/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp b/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp
index 7401e470ed4f2c..9a3fac623fbd7d 100644
--- a/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp
+++ b/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp
@@ -99,7 +99,8 @@ struct TestEmulateNarrowTypePass
arith::populateArithNarrowTypeEmulationPatterns(typeConverter, patterns);
memref::populateMemRefNarrowTypeEmulationPatterns(typeConverter, patterns);
- vector::populateVectorNarrowTypeEmulationPatterns(typeConverter, patterns);
+ vector::populateVectorNarrowTypeEmulationPatterns(typeConverter, patterns,
+ atomicStore);
if (failed(applyPartialConversion(op, target, std::move(patterns))))
signalPassFailure();
@@ -118,6 +119,11 @@ struct TestEmulateNarrowTypePass
*this, "skip-memref-type-conversion",
llvm::cl::desc("disable memref type conversion (to test failures)"),
llvm::cl::init(false)};
+
+ Option<bool> atomicStore{
+ *this, "atomic-store",
+ llvm::cl::desc("use atomic store instead of load-modify-write"),
+ llvm::cl::init(true)};
};
} // namespace
>From 99ee5d01675738f24cadb1e2c5e123a4a0917797 Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Tue, 3 Dec 2024 01:44:40 +0800
Subject: [PATCH 02/12] Refactoring
---
.../Transforms/VectorEmulateNarrowType.cpp | 64 ++++++++++---------
1 file changed, 34 insertions(+), 30 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 278b42a5b7c104..d5012e387944f8 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -46,6 +46,9 @@ using namespace mlir;
#define DBGSNL() (llvm::dbgs() << "\n")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
+using VectorValue = TypedValue<VectorType>;
+using MemRefValue = TypedValue<MemRefType>;
+
/// Returns a compressed mask for the emulated vector. For example, when
/// emulating an eight-element `i8` vector with `i32` (i.e. when the source
/// elements span two dest elements), this method compresses `vector<8xi1>`
@@ -255,8 +258,8 @@ static Value staticallyInsertSubvector(OpBuilder &rewriter, Location loc,
/// function emits multiple `vector.extract` and `vector.insert` ops, so only
/// use it when `offset` cannot be folded into a constant value.
static Value dynamicallyExtractSubVector(OpBuilder &rewriter, Location loc,
- TypedValue<VectorType> source,
- Value dest, OpFoldResult offset,
+ VectorValue source, Value dest,
+ OpFoldResult offset,
int64_t numElementsToExtract) {
for (int i = 0; i < numElementsToExtract; ++i) {
Value extractLoc =
@@ -273,8 +276,8 @@ static Value dynamicallyExtractSubVector(OpBuilder &rewriter, Location loc,
/// Inserts a 1-D subvector into a 1-D `dest` vector at index `destOffsetVar`.
static Value dynamicallyInsertSubVector(RewriterBase &rewriter, Location loc,
- TypedValue<VectorType> source,
- Value dest, OpFoldResult destOffsetVar,
+ VectorValue source, Value dest,
+ OpFoldResult destOffsetVar,
size_t length) {
assert(length > 0 && "length must be greater than 0");
Value destOffsetVal =
@@ -295,11 +298,12 @@ static Value dynamicallyInsertSubVector(RewriterBase &rewriter, Location loc,
/// specifically, use `emulatedElemType` for loading a vector of `origElemType`.
/// The load location is given by `base` and `linearizedIndices`, and the
/// load size is given by `numEmulatedElementsToLoad`.
-static TypedValue<VectorType>
-emulatedVectorLoad(OpBuilder &rewriter, Location loc, Value base,
- OpFoldResult linearizedIndices,
- int64_t numEmultedElementsToLoad, Type origElemType,
- Type emulatedElemType) {
+static VectorValue emulatedVectorLoad(OpBuilder &rewriter, Location loc,
+ Value base,
+ OpFoldResult linearizedIndices,
+ int64_t numEmultedElementsToLoad,
+ Type origElemType,
+ Type emulatedElemType) {
auto scale = emulatedElemType.getIntOrFloatBitWidth() /
origElemType.getIntOrFloatBitWidth();
auto newLoad = rewriter.create<vector::LoadOp>(
@@ -312,9 +316,9 @@ emulatedVectorLoad(OpBuilder &rewriter, Location loc, Value base,
/// Atomically store a subbyte-sized value to memory, with a mask.
static void atomicStore(OpBuilder &builder, Location loc,
- TypedValue<MemRefType> emulatedMemref,
- Value linearizedIndex, TypedValue<VectorType> value,
- Value mask, int64_t) {
+ MemRefValue emulatedMemref, Value linearizedIndex,
+ VectorValue value, Value mask,
+ int64_t numSrcElemsPerDest) {
auto atomicOp = builder.create<memref::GenericAtomicRMWOp>(
loc, emulatedMemref, ValueRange{linearizedIndex});
Value origValue = atomicOp.getCurrentValue();
@@ -338,9 +342,9 @@ static void atomicStore(OpBuilder &builder, Location loc,
/// Generate a non-atomic read-modify-write sequence for subbyte storing.
static void rmwStore(OpBuilder &rewriter, Location loc,
- TypedValue<MemRefType> emulatedMemref,
- Value linearizedIndex, TypedValue<VectorType> value,
- Value mask, int64_t numSrcElemsPerDest) {
+ MemRefValue emulatedMemref, Value linearizedIndex,
+ VectorValue value, Value mask,
+ int64_t numSrcElemsPerDest) {
auto emulatedIOType =
VectorType::get({1}, emulatedMemref.getType().getElementType());
auto elemLoad = rewriter.create<vector::LoadOp>(
@@ -363,7 +367,7 @@ static_assert(std::is_same_v<decltype(atomicStore), decltype(rmwStore)> &&
// Extract a slice of a vector, and insert it into a byte vector.
static Value extractSliceIntoByte(ConversionPatternRewriter &rewriter,
- Location loc, TypedValue<VectorType> vector,
+ Location loc, VectorValue vector,
int64_t sliceOffset, int64_t sliceNumElements,
int64_t byteOffset) {
auto vectorElementType = vector.getType().getElementType();
@@ -404,7 +408,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
auto loc = op.getLoc();
auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
- auto valueToStore = cast<TypedValue<VectorType>>(op.getValueToStore());
+ auto valueToStore = cast<VectorValue>(op.getValueToStore());
auto oldElementType = valueToStore.getType().getElementType();
auto newElementType = convertedType.getElementType();
int srcBits = oldElementType.getIntOrFloatBitWidth();
@@ -455,7 +459,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
return failure();
}
- auto emulatedMemref = cast<TypedValue<MemRefType>>(adaptor.getBase());
+ auto emulatedMemref = cast<MemRefValue>(adaptor.getBase());
// Shortcut: conditions when subbyte store at the front is not needed:
// 1. The source vector size is a multiple of the byte size
@@ -504,8 +508,8 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
frontSubWidthStoreElem, *foldedNumFrontPadElems);
subEmulatedWidthStore(rewriter, loc, emulatedMemref, currentDestIndex,
- cast<TypedValue<VectorType>>(value),
- frontMask.getResult(), numSrcElemsPerDest);
+ cast<VectorValue>(value), frontMask.getResult(),
+ numSrcElemsPerDest);
currentDestIndex = rewriter.create<arith::AddIOp>(
loc, rewriter.getIndexType(), currentDestIndex, constantOne);
@@ -546,9 +550,9 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
// but their length is smaller than the emulated width.
auto remainingElements = origElements - currentSourceIndex;
if (remainingElements != 0) {
- auto subWidthStorePart = extractSliceIntoByte(
- rewriter, loc, cast<TypedValue<VectorType>>(valueToStore),
- currentSourceIndex, remainingElements, 0);
+ auto subWidthStorePart =
+ extractSliceIntoByte(rewriter, loc, cast<VectorValue>(valueToStore),
+ currentSourceIndex, remainingElements, 0);
// Generate back mask
auto maskValues = SmallVector<bool>(numSrcElemsPerDest, 0);
@@ -557,7 +561,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
loc, DenseElementsAttr::get(subWidthStoreMaskType, maskValues));
subEmulatedWidthStore(rewriter, loc, emulatedMemref, currentDestIndex,
- cast<TypedValue<VectorType>>(subWidthStorePart),
+ cast<VectorValue>(subWidthStorePart),
backMask.getResult(), numSrcElemsPerDest);
}
@@ -778,7 +782,7 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
auto resultVector = rewriter.create<arith::ConstantOp>(
loc, op.getType(), rewriter.getZeroAttr(op.getType()));
result = dynamicallyExtractSubVector(
- rewriter, loc, dyn_cast<TypedValue<VectorType>>(result), resultVector,
+ rewriter, loc, cast<VectorValue>(result), resultVector,
linearizedInfo.intraDataOffset, origElements);
} else if (isUnalignedEmulation) {
result = staticallyExtractSubvector(
@@ -899,8 +903,8 @@ struct ConvertVectorMaskedLoad final
loc, newBitcastType, rewriter.getZeroAttr(newBitcastType));
if (!foldedIntraVectorOffset) {
passthru = dynamicallyInsertSubVector(
- rewriter, loc, dyn_cast<TypedValue<VectorType>>(passthru),
- emptyVector, linearizedInfo.intraDataOffset, origElements);
+ rewriter, loc, cast<VectorValue>(passthru), emptyVector,
+ linearizedInfo.intraDataOffset, origElements);
} else if (isUnalignedEmulation) {
passthru = staticallyInsertSubvector(rewriter, loc, passthru, emptyVector,
*foldedIntraVectorOffset);
@@ -927,7 +931,7 @@ struct ConvertVectorMaskedLoad final
loc, newSelectMaskType, rewriter.getZeroAttr(newSelectMaskType));
if (!foldedIntraVectorOffset) {
mask = dynamicallyInsertSubVector(
- rewriter, loc, dyn_cast<TypedValue<VectorType>>(mask), emptyMask,
+ rewriter, loc, cast<VectorValue>(mask), emptyMask,
linearizedInfo.intraDataOffset, origElements);
} else if (isUnalignedEmulation) {
mask = staticallyInsertSubvector(rewriter, loc, op.getMask(), emptyMask,
@@ -938,8 +942,8 @@ struct ConvertVectorMaskedLoad final
rewriter.create<arith::SelectOp>(loc, mask, bitCast, passthru);
if (!foldedIntraVectorOffset) {
result = dynamicallyExtractSubVector(
- rewriter, loc, dyn_cast<TypedValue<VectorType>>(result),
- op.getPassThru(), linearizedInfo.intraDataOffset, origElements);
+ rewriter, loc, cast<VectorValue>(result), op.getPassThru(),
+ linearizedInfo.intraDataOffset, origElements);
} else if (isUnalignedEmulation) {
result = staticallyExtractSubvector(
rewriter, loc, result, *foldedIntraVectorOffset, origElements);
>From c329d898bcf941eb9fecb208f18759a844b40319 Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Wed, 4 Dec 2024 18:41:48 +0800
Subject: [PATCH 03/12] updates
---
.../Transforms/VectorEmulateNarrowType.cpp | 47 +++++++++----------
.../vector-emulate-narrow-type-unaligned.mlir | 2 -
2 files changed, 23 insertions(+), 26 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index d5012e387944f8..2aca8850d075c2 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -316,11 +316,11 @@ static VectorValue emulatedVectorLoad(OpBuilder &rewriter, Location loc,
/// Atomically store a subbyte-sized value to memory, with a mask.
static void atomicStore(OpBuilder &builder, Location loc,
- MemRefValue emulatedMemref, Value linearizedIndex,
- VectorValue value, Value mask,
+ MemRefValue linearizedMemref, Value linearizedIndex,
+ VectorValue valueToStore, Value mask,
int64_t numSrcElemsPerDest) {
auto atomicOp = builder.create<memref::GenericAtomicRMWOp>(
- loc, emulatedMemref, ValueRange{linearizedIndex});
+ loc, linearizedMemref, ValueRange{linearizedIndex});
Value origValue = atomicOp.getCurrentValue();
OpBuilder::InsertionGuard guard(builder);
@@ -331,10 +331,10 @@ static void atomicStore(OpBuilder &builder, Location loc,
auto fromElem = builder.create<vector::FromElementsOp>(loc, oneVectorType,
ValueRange{origValue});
auto vectorBitCast =
- builder.create<vector::BitCastOp>(loc, value.getType(), fromElem);
+ builder.create<vector::BitCastOp>(loc, valueToStore.getType(), fromElem);
auto select =
- builder.create<arith::SelectOp>(loc, mask, value, vectorBitCast);
+ builder.create<arith::SelectOp>(loc, mask, valueToStore, vectorBitCast);
auto bitcast2 = builder.create<vector::BitCastOp>(loc, oneVectorType, select);
auto extract = builder.create<vector::ExtractOp>(loc, bitcast2, 0);
builder.create<memref::AtomicYieldOp>(loc, extract.getResult());
@@ -342,13 +342,13 @@ static void atomicStore(OpBuilder &builder, Location loc,
/// Generate a non-atomic read-modify-write sequence for subbyte storing.
static void rmwStore(OpBuilder &rewriter, Location loc,
- MemRefValue emulatedMemref, Value linearizedIndex,
+ MemRefValue linearizedMemref, Value linearizedIndex,
VectorValue value, Value mask,
int64_t numSrcElemsPerDest) {
auto emulatedIOType =
- VectorType::get({1}, emulatedMemref.getType().getElementType());
+ VectorType::get({1}, linearizedMemref.getType().getElementType());
auto elemLoad = rewriter.create<vector::LoadOp>(
- loc, emulatedIOType, emulatedMemref, ValueRange{linearizedIndex});
+ loc, emulatedIOType, linearizedMemref, ValueRange{linearizedIndex});
auto fromBitcast = rewriter.create<vector::BitCastOp>(
loc,
VectorType::get({numSrcElemsPerDest}, value.getType().getElementType()),
@@ -356,15 +356,10 @@ static void rmwStore(OpBuilder &rewriter, Location loc,
auto select = rewriter.create<arith::SelectOp>(loc, mask, fromBitcast, value);
auto toBitcast =
rewriter.create<vector::BitCastOp>(loc, emulatedIOType, select);
- rewriter.create<vector::StoreOp>(loc, toBitcast, emulatedMemref,
+ rewriter.create<vector::StoreOp>(loc, toBitcast, linearizedMemref,
linearizedIndex);
}
-static_assert(std::is_same_v<decltype(atomicStore), decltype(rmwStore)> &&
- "`atomicStore` and `rmwStore` must have same signature, as per "
- "the design to keep the code clean, which one to call is "
- "determined by the `useAtomicWrites` flag.");
-
// Extract a slice of a vector, and insert it into a byte vector.
static Value extractSliceIntoByte(ConversionPatternRewriter &rewriter,
Location loc, VectorValue vector,
@@ -459,7 +454,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
return failure();
}
- auto emulatedMemref = cast<MemRefValue>(adaptor.getBase());
+ auto linearizedMemref = cast<MemRefValue>(adaptor.getBase());
// Shortcut: conditions when subbyte store at the front is not needed:
// 1. The source vector size is a multiple of the byte size
@@ -470,7 +465,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
loc, VectorType::get(numElements, newElementType),
op.getValueToStore());
rewriter.replaceOpWithNewOp<vector::StoreOp>(
- op, bitCast.getResult(), emulatedMemref,
+ op, bitCast.getResult(), linearizedMemref,
getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
return success();
}
@@ -507,7 +502,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
extractSliceIntoByte(rewriter, loc, valueToStore, 0,
frontSubWidthStoreElem, *foldedNumFrontPadElems);
- subEmulatedWidthStore(rewriter, loc, emulatedMemref, currentDestIndex,
+ subEmulatedWidthStore(rewriter, loc, linearizedMemref, currentDestIndex,
cast<VectorValue>(value), frontMask.getResult(),
numSrcElemsPerDest);
@@ -525,20 +520,19 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
int64_t fullWidthStoreSize =
(origElements - currentSourceIndex) / numSrcElemsPerDest;
int64_t numNonFullWidthElements = fullWidthStoreSize * numSrcElemsPerDest;
- if (fullWidthStoreSize != 0) {
+ if (fullWidthStoreSize > 0) {
auto fullWidthStorePart = staticallyExtractSubvector(
rewriter, loc, valueToStore, currentSourceIndex,
numNonFullWidthElements);
- auto originType = dyn_cast<VectorType>(fullWidthStorePart.getType());
- auto memrefElemType =
- dyn_cast<MemRefType>(emulatedMemref.getType()).getElementType();
+ auto originType = cast<VectorType>(fullWidthStorePart.getType());
+ auto memrefElemType = getElementTypeOrSelf(linearizedMemref.getType());
auto storeType = VectorType::get(
{originType.getNumElements() / numSrcElemsPerDest}, memrefElemType);
auto bitCast = rewriter.create<vector::BitCastOp>(loc, storeType,
fullWidthStorePart);
- rewriter.create<vector::StoreOp>(loc, bitCast.getResult(), emulatedMemref,
- currentDestIndex);
+ rewriter.create<vector::StoreOp>(loc, bitCast.getResult(),
+ linearizedMemref, currentDestIndex);
currentSourceIndex += numNonFullWidthElements;
currentDestIndex = rewriter.create<arith::AddIOp>(
@@ -560,7 +554,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
auto backMask = rewriter.create<arith::ConstantOp>(
loc, DenseElementsAttr::get(subWidthStoreMaskType, maskValues));
- subEmulatedWidthStore(rewriter, loc, emulatedMemref, currentDestIndex,
+ subEmulatedWidthStore(rewriter, loc, linearizedMemref, currentDestIndex,
cast<VectorValue>(subWidthStorePart),
backMask.getResult(), numSrcElemsPerDest);
}
@@ -573,6 +567,11 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
/// configuration, it could be an atomic store or an RMW sequence.
template <typename... Args>
void subEmulatedWidthStore(Args &&...args) const {
+ static_assert(
+ std::is_same_v<decltype(atomicStore), decltype(rmwStore)> &&
+ "`atomicStore` and `rmwStore` must have same signature, as per "
+ "the design to keep the code clean, which one to call is "
+ "determined by the `useAtomicWrites` flag.");
std::function<decltype(atomicStore)> storeFunc =
useAtomicWrites_ ? atomicStore : rmwStore;
storeFunc(std::forward<Args>(args)...);
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
index fd526ada6cb7b2..e37053c25ff066 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
@@ -341,8 +341,6 @@ func.func @vector_maskedload_i4_constant_mask_unaligned(%passthru: vector<5xi2>)
/// vector.store
///----------------------------------------------------------------------------------------
-// -----
-
func.func @vector_store_i2_atomic(%arg0: vector<7xi2>) {
%0 = memref.alloc() : memref<3x7xi2>
%c0 = arith.constant 0 : index
>From c84bc0ea5db3003753b982660e215e92eecd844f Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Wed, 4 Dec 2024 23:46:43 +0800
Subject: [PATCH 04/12] update comments
---
.../Transforms/VectorEmulateNarrowType.cpp | 27 ++++++++++++++++---
1 file changed, 23 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 2aca8850d075c2..b04a76c0c1c43e 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -314,7 +314,17 @@ static VectorValue emulatedVectorLoad(OpBuilder &rewriter, Location loc,
newLoad);
}
-/// Atomically store a subbyte-sized value to memory, with a mask.
+/// Emits a `memref.generic_atomic_rmw` op to store a subbyte-sized value into a
+/// byte in memory, with a mask. The `valueToStore` is a vector of subbyte-sized
+/// elements that together span 8 bits, and the mask is used to select which elements
+/// to store.
+///
+/// Before:
+/// memory = |ab|cd|ef|12| : <4xi2> (<1xi8>)
+/// valueToStore = |01|23|45|67| : vector<4xi2>
+/// mask = |0|0|1|1| : vector<4xi1>
+/// After:
+/// memory = |ab|cd|45|67| : <4xi2> (<1xi8>)
static void atomicStore(OpBuilder &builder, Location loc,
MemRefValue linearizedMemref, Value linearizedIndex,
VectorValue valueToStore, Value mask,
@@ -326,7 +336,7 @@ static void atomicStore(OpBuilder &builder, Location loc,
OpBuilder::InsertionGuard guard(builder);
builder.setInsertionPointToStart(atomicOp.getBody());
- // i8 -> <1xi8> -> <numSrcElemsPerDest x i.>
+ // i8 -> <1xi8> -> <numSrcElemsPerDest x i.>:
auto oneVectorType = VectorType::get({1}, origValue.getType());
auto fromElem = builder.create<vector::FromElementsOp>(loc, oneVectorType,
ValueRange{origValue});
@@ -341,6 +351,7 @@ static void atomicStore(OpBuilder &builder, Location loc,
}
/// Generate a non-atomic read-modify-write sequence for subbyte storing.
+/// It has similar logic to `atomicStore`, but without the atomicity.
static void rmwStore(OpBuilder &rewriter, Location loc,
MemRefValue linearizedMemref, Value linearizedIndex,
VectorValue value, Value mask,
@@ -360,7 +371,15 @@ static void rmwStore(OpBuilder &rewriter, Location loc,
linearizedIndex);
}
-// Extract a slice of a vector, and insert it into a byte vector.
+/// Extract a slice from vector `vector`, with the size of `sliceNumElements`,
+/// and insert it into a zero, byte vector at offset `byteOffset`. For example:
+/// Inputs:
+/// vector = |01|23|45|67| : vector<4xi2>
+/// sliceOffset = 1
+/// sliceNumElements = 2
+/// byteOffset = 2
+/// Output:
+/// vector = |00|00|23|45| : vector<4xi2>
static Value extractSliceIntoByte(ConversionPatternRewriter &rewriter,
Location loc, VectorValue vector,
int64_t sliceOffset, int64_t sliceNumElements,
@@ -484,7 +503,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
// the remaining elements are aligned to the width boundary.
auto frontSubWidthStoreElem =
(numSrcElemsPerDest - *foldedNumFrontPadElems) % numSrcElemsPerDest;
- if (frontSubWidthStoreElem != 0) {
+ if (frontSubWidthStoreElem > 0) {
SmallVector<bool> frontMaskValues(numSrcElemsPerDest, false);
if (*foldedNumFrontPadElems + origElements < numSrcElemsPerDest) {
std::fill_n(frontMaskValues.begin() + *foldedNumFrontPadElems,
>From 11be4862edc8c2f7c11faac148714a8e84ddeaf7 Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Thu, 5 Dec 2024 00:27:09 +0800
Subject: [PATCH 05/12] updates
---
.../Vector/Transforms/VectorEmulateNarrowType.cpp | 15 +++++++++++----
1 file changed, 11 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index b04a76c0c1c43e..4ca492bd1b58bb 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -329,6 +329,7 @@ static void atomicStore(OpBuilder &builder, Location loc,
MemRefValue linearizedMemref, Value linearizedIndex,
VectorValue valueToStore, Value mask,
int64_t numSrcElemsPerDest) {
+ assert(valueToStore.getType().getRank() == 1 && "expected 1-D vector");
auto atomicOp = builder.create<memref::GenericAtomicRMWOp>(
loc, linearizedMemref, ValueRange{linearizedIndex});
Value origValue = atomicOp.getCurrentValue();
@@ -336,7 +337,6 @@ static void atomicStore(OpBuilder &builder, Location loc,
OpBuilder::InsertionGuard guard(builder);
builder.setInsertionPointToStart(atomicOp.getBody());
- // i8 -> <1xi8> -> <numSrcElemsPerDest x i.>:
auto oneVectorType = VectorType::get({1}, origValue.getType());
auto fromElem = builder.create<vector::FromElementsOp>(loc, oneVectorType,
ValueRange{origValue});
@@ -354,17 +354,20 @@ static void atomicStore(OpBuilder &builder, Location loc,
/// It has similar logic to `atomicStore`, but without the atomicity.
static void rmwStore(OpBuilder &rewriter, Location loc,
MemRefValue linearizedMemref, Value linearizedIndex,
- VectorValue value, Value mask,
+ VectorValue valueToStore, Value mask,
int64_t numSrcElemsPerDest) {
+ assert(valueToStore.getType().getRank() == 1 && "expected 1-D vector");
auto emulatedIOType =
VectorType::get({1}, linearizedMemref.getType().getElementType());
auto elemLoad = rewriter.create<vector::LoadOp>(
loc, emulatedIOType, linearizedMemref, ValueRange{linearizedIndex});
auto fromBitcast = rewriter.create<vector::BitCastOp>(
loc,
- VectorType::get({numSrcElemsPerDest}, value.getType().getElementType()),
+ VectorType::get({numSrcElemsPerDest},
+ valueToStore.getType().getElementType()),
elemLoad);
- auto select = rewriter.create<arith::SelectOp>(loc, mask, fromBitcast, value);
+ auto select =
+ rewriter.create<arith::SelectOp>(loc, mask, fromBitcast, valueToStore);
auto toBitcast =
rewriter.create<vector::BitCastOp>(loc, emulatedIOType, select);
rewriter.create<vector::StoreOp>(loc, toBitcast, linearizedMemref,
@@ -384,7 +387,11 @@ static Value extractSliceIntoByte(ConversionPatternRewriter &rewriter,
Location loc, VectorValue vector,
int64_t sliceOffset, int64_t sliceNumElements,
int64_t byteOffset) {
+ assert(vector.getType().getRank() == 1 && "expected 1-D vector");
auto vectorElementType = vector.getType().getElementType();
+ assert(
+ sliceNumElements * vectorElementType.getIntOrFloatBitWidth() <= 8 &&
+ "sliceNumElements * vector element size must be less than or equal to 8");
assert(8 % vectorElementType.getIntOrFloatBitWidth() == 0 &&
"vector element must be a valid sub-byte type");
auto scale = 8 / vectorElementType.getIntOrFloatBitWidth();
>From 6c05b72096ddd96d7e616c82ec06ba0e078139ab Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Thu, 5 Dec 2024 08:50:50 +0800
Subject: [PATCH 06/12] fix according to comments
---
.../Vector/Transforms/VectorRewritePatterns.h | 5 +++--
.../Transforms/VectorEmulateNarrowType.cpp | 16 +++++++++-------
2 files changed, 12 insertions(+), 9 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 64bb3a2204cfdc..43478aacb50a14 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -364,8 +364,9 @@ void populateVectorMaskMaterializationPatterns(RewritePatternSet &patterns,
PatternBenefit benefit = 1);
/// Appends patterns for emulating vector operations over narrow types with ops
-/// over wider types. `useAtomicWrites` indicates whether to use atomic
-/// operations in the places where thread contention is possible.
+/// over wider types. `useAtomicWrites` indicates whether to use
+/// `memref.generic_atomic_rmw` ops to perform atomic subbyte stores, or a plain
+/// RMW sequence otherwise.
void populateVectorNarrowTypeEmulationPatterns(
const arith::NarrowTypeEmulationConverter &typeConverter,
RewritePatternSet &patterns, bool useAtomicWrites = true);
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 4ca492bd1b58bb..3c5534c4a932ad 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -319,12 +319,14 @@ static VectorValue emulatedVectorLoad(OpBuilder &rewriter, Location loc,
/// elements that together span 8 bits, and the mask is used to select which elements
/// to store.
///
-/// Before:
-/// memory = |ab|cd|ef|12| : <4xi2> (<1xi8>)
-/// valueToStore = |01|23|45|67| : vector<4xi2>
+/// Inputs:
+/// linearizedMemref = |a|b|c|d| : <4xi2> (<1xi8>)
+/// linearizedIndex = 2
+/// valueToStore = |e|f|g|h| : vector<4xi2>
/// mask = |0|0|1|1| : vector<4xi1>
-/// After:
-/// memory = |ab|cd|45|67| : <4xi2> (<1xi8>)
+///
+/// Result:
+/// linearizedMemref = |a|b|g|h| : <4xi2> (<1xi8>)
static void atomicStore(OpBuilder &builder, Location loc,
MemRefValue linearizedMemref, Value linearizedIndex,
VectorValue valueToStore, Value mask,
@@ -374,8 +376,8 @@ static void rmwStore(OpBuilder &rewriter, Location loc,
linearizedIndex);
}
-/// Extract a slice from vector `vector`, with the size of `sliceNumElements`,
-/// and insert it into a zero, byte vector at offset `byteOffset`. For example:
+/// Extract `sliceNumElements` from source `vector` at `sliceOffset`,
+/// and insert it into an empty vector at offset `byteOffset`.
/// Inputs:
/// vector = |01|23|45|67| : vector<4xi2>
/// sliceOffset = 1
>From 40acc9984e7b672730f778ec0c83dc1565021518 Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Sat, 7 Dec 2024 23:35:07 +0800
Subject: [PATCH 07/12] Updates according to comments
---
....mlir => vector-emulate-narrow-type-unaligned-non-atomic.mlir} | 0
1 file changed, 0 insertions(+), 0 deletions(-)
rename mlir/test/Dialect/Vector/{vector-emulate-narrow-type-unaligned-rmw.mlir => vector-emulate-narrow-type-unaligned-non-atomic.mlir} (100%)
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned-rmw.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned-non-atomic.mlir
similarity index 100%
rename from mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned-rmw.mlir
rename to mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned-non-atomic.mlir
>From 62f9b406bf4530a7e751c7acc1edc4900a0bf0fa Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Tue, 10 Dec 2024 21:43:23 +0800
Subject: [PATCH 08/12] update for comments
---
.../Transforms/VectorEmulateNarrowType.cpp | 57 +++++++++----------
1 file changed, 27 insertions(+), 30 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 3c5534c4a932ad..1d4c546db0bf0e 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -320,17 +320,20 @@ static VectorValue emulatedVectorLoad(OpBuilder &rewriter, Location loc,
/// to store.
///
/// Inputs:
-/// linearizedMemref = |a|b|c|d| : <4xi2> (<1xi8>)
+/// linearizedMemref = |2|2|2|2| : <4xi2> (<1xi8>)
/// linearizedIndex = 2
-/// valueToStore = |e|f|g|h| : vector<4xi2>
+/// valueToStore = |3|3|3|3| : vector<4xi2>
/// mask = |0|0|1|1| : vector<4xi1>
///
/// Result:
-/// linearizedMemref = |a|b|g|h| : <4xi2> (<1xi8>)
+/// linearizedMemref = |2|2|3|3| : <4xi2> (<1xi8>)
static void atomicStore(OpBuilder &builder, Location loc,
MemRefValue linearizedMemref, Value linearizedIndex,
- VectorValue valueToStore, Value mask,
- int64_t numSrcElemsPerDest) {
+ VectorValue valueToStore, Value mask) {
+ // `numSrcElemsPerDest` is not used in this function, but to keep the function
+ // signature consistent with `rmwStore` so as to simplify the pattern to
+ // invoke.
+
assert(valueToStore.getType().getRank() == 1 && "expected 1-D vector");
auto atomicOp = builder.create<memref::GenericAtomicRMWOp>(
loc, linearizedMemref, ValueRange{linearizedIndex});
@@ -354,37 +357,33 @@ static void atomicStore(OpBuilder &builder, Location loc,
/// Generate a non-atomic read-modify-write sequence for subbyte storing.
/// It has similar logic to `atomicStore`, but without the atomicity.
-static void rmwStore(OpBuilder &rewriter, Location loc,
+static void rmwStore(OpBuilder &builder, Location loc,
MemRefValue linearizedMemref, Value linearizedIndex,
- VectorValue valueToStore, Value mask,
- int64_t numSrcElemsPerDest) {
+ VectorValue valueToStore, Value mask) {
assert(valueToStore.getType().getRank() == 1 && "expected 1-D vector");
- auto emulatedIOType =
+ auto oneElemVecType =
VectorType::get({1}, linearizedMemref.getType().getElementType());
- auto elemLoad = rewriter.create<vector::LoadOp>(
- loc, emulatedIOType, linearizedMemref, ValueRange{linearizedIndex});
- auto fromBitcast = rewriter.create<vector::BitCastOp>(
- loc,
- VectorType::get({numSrcElemsPerDest},
- valueToStore.getType().getElementType()),
- elemLoad);
- auto select =
- rewriter.create<arith::SelectOp>(loc, mask, fromBitcast, valueToStore);
- auto toBitcast =
- rewriter.create<vector::BitCastOp>(loc, emulatedIOType, select);
- rewriter.create<vector::StoreOp>(loc, toBitcast, linearizedMemref,
- linearizedIndex);
+ auto origValue = builder.create<vector::LoadOp>(
+ loc, oneElemVecType, linearizedMemref, ValueRange{linearizedIndex});
+ auto castedValue =
+ builder.create<vector::BitCastOp>(loc, valueToStore.getType(), origValue);
+ auto result =
+ builder.create<arith::SelectOp>(loc, mask, castedValue, valueToStore);
+ auto resultBitcast =
+ builder.create<vector::BitCastOp>(loc, oneElemVecType, result);
+ builder.create<vector::StoreOp>(loc, resultBitcast, linearizedMemref,
+ linearizedIndex);
}
/// Extract `sliceNumElements` elements from the source `vector` at
/// `sliceOffset`, and insert them into an empty vector at offset `byteOffset`.
/// Inputs:
-/// vector = |01|23|45|67| : vector<4xi2>
+/// vector = |1|2|3|4| : vector<4xi2>
/// sliceOffset = 1
/// sliceNumElements = 2
/// byteOffset = 2
/// Output:
-/// vector = |00|00|23|45| : vector<4xi2>
+/// vector = |0|0|2|3| : vector<4xi2>
static Value extractSliceIntoByte(ConversionPatternRewriter &rewriter,
Location loc, VectorValue vector,
int64_t sliceOffset, int64_t sliceNumElements,
@@ -402,9 +401,8 @@ static Value extractSliceIntoByte(ConversionPatternRewriter &rewriter,
rewriter.getZeroAttr(VectorType::get({scale}, vectorElementType)));
auto extracted = staticallyExtractSubvector(rewriter, loc, vector,
sliceOffset, sliceNumElements);
- auto inserted = staticallyInsertSubvector(rewriter, loc, extracted,
- emptyByteVector, byteOffset);
- return inserted;
+ return staticallyInsertSubvector(rewriter, loc, extracted, emptyByteVector,
+ byteOffset);
}
namespace {
@@ -531,8 +529,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
frontSubWidthStoreElem, *foldedNumFrontPadElems);
subEmulatedWidthStore(rewriter, loc, linearizedMemref, currentDestIndex,
- cast<VectorValue>(value), frontMask.getResult(),
- numSrcElemsPerDest);
+ cast<VectorValue>(value), frontMask.getResult());
currentDestIndex = rewriter.create<arith::AddIOp>(
loc, rewriter.getIndexType(), currentDestIndex, constantOne);
@@ -584,7 +581,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
subEmulatedWidthStore(rewriter, loc, linearizedMemref, currentDestIndex,
cast<VectorValue>(subWidthStorePart),
- backMask.getResult(), numSrcElemsPerDest);
+ backMask.getResult());
}
rewriter.eraseOp(op);
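With `numSrcElemsPerDest` gone, `atomicStore` and `rmwStore` now share the same signature, so the pattern can pick one of them behind a single helper. The `subEmulatedWidthStore` calls above go through such a helper; its body is not shown in this hunk, so the sketch below is hypothetical, assuming a `useAtomicWrites_` flag stored on the pattern.

// Hypothetical dispatch inside ConvertVectorStore; `useAtomicWrites_` and the
// helper body are assumptions, only the call sites appear in this patch.
void subEmulatedWidthStore(ConversionPatternRewriter &rewriter, Location loc,
                           MemRefValue linearizedMemref, Value linearizedIndex,
                           VectorValue valueToStore, Value mask) const {
  // Both helpers take identical arguments, so a plain function pointer works.
  auto storeFunc = useAtomicWrites_ ? atomicStore : rmwStore;
  storeFunc(rewriter, loc, linearizedMemref, linearizedIndex, valueToStore,
            mask);
}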
>From 15d2ad3ea270f1eba314a707749e1b8ea78a0159 Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Tue, 10 Dec 2024 21:53:55 +0800
Subject: [PATCH 09/12] checkpoint
---
.../Transforms/VectorEmulateNarrowType.cpp | 24 +++++++++----------
1 file changed, 11 insertions(+), 13 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 1d4c546db0bf0e..c091aec6246988 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -330,10 +330,6 @@ static VectorValue emulatedVectorLoad(OpBuilder &rewriter, Location loc,
static void atomicStore(OpBuilder &builder, Location loc,
MemRefValue linearizedMemref, Value linearizedIndex,
VectorValue valueToStore, Value mask) {
- // `numSrcElemsPerDest` is not used in this function, but to keep the function
- // signature consistent with `rmwStore` so as to simplify the pattern to
- // invoke.
-
assert(valueToStore.getType().getRank() == 1 && "expected 1-D vector");
auto atomicOp = builder.create<memref::GenericAtomicRMWOp>(
loc, linearizedMemref, ValueRange{linearizedIndex});
@@ -342,17 +338,19 @@ static void atomicStore(OpBuilder &builder, Location loc,
OpBuilder::InsertionGuard guard(builder);
builder.setInsertionPointToStart(atomicOp.getBody());
+ // Construct the vector value from the scalar.
auto oneVectorType = VectorType::get({1}, origValue.getType());
- auto fromElem = builder.create<vector::FromElementsOp>(loc, oneVectorType,
+ Value origVecValue = builder.create<vector::FromElementsOp>(loc, oneVectorType,
ValueRange{origValue});
- auto vectorBitCast =
- builder.create<vector::BitCastOp>(loc, valueToStore.getType(), fromElem);
-
- auto select =
- builder.create<arith::SelectOp>(loc, mask, valueToStore, vectorBitCast);
- auto bitcast2 = builder.create<vector::BitCastOp>(loc, oneVectorType, select);
- auto extract = builder.create<vector::ExtractOp>(loc, bitcast2, 0);
- builder.create<memref::AtomicYieldOp>(loc, extract.getResult());
+ origVecValue =
+ builder.create<vector::BitCastOp>(loc, valueToStore.getType(), origVecValue);
+
+ // Construct the final masked value and yield it.
+ Value maskedValue =
+ builder.create<arith::SelectOp>(loc, mask, valueToStore, origVecValue);
+ maskedValue = builder.create<vector::BitCastOp>(loc, oneVectorType, maskedValue);
+ auto scalarMaskedValue = builder.create<vector::ExtractOp>(loc, maskedValue, 0);
+ builder.create<memref::AtomicYieldOp>(loc, scalarMaskedValue);
}
/// Generate a non-atomic read-modify-write sequence for subbyte storing.
>From 2b90bf73f90cb846f1f881a4dfa2fcce12179ee1 Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Tue, 10 Dec 2024 22:01:30 +0800
Subject: [PATCH 10/12] updates
---
.../Transforms/VectorEmulateNarrowType.cpp | 37 ++++++++++++-------
1 file changed, 23 insertions(+), 14 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index c091aec6246988..95eab0998ee9df 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -331,6 +331,9 @@ static void atomicStore(OpBuilder &builder, Location loc,
MemRefValue linearizedMemref, Value linearizedIndex,
VectorValue valueToStore, Value mask) {
assert(valueToStore.getType().getRank() == 1 && "expected 1-D vector");
+
+ // Create an atomic load-modify-write region using
+ // `memref.generic_atomic_rmw`.
auto atomicOp = builder.create<memref::GenericAtomicRMWOp>(
loc, linearizedMemref, ValueRange{linearizedIndex});
Value origValue = atomicOp.getCurrentValue();
@@ -340,16 +343,18 @@ static void atomicStore(OpBuilder &builder, Location loc,
// Construct the vector value from the scalar.
auto oneVectorType = VectorType::get({1}, origValue.getType());
- Value origVecValue = builder.create<vector::FromElementsOp>(loc, oneVectorType,
- ValueRange{origValue});
- origVecValue =
- builder.create<vector::BitCastOp>(loc, valueToStore.getType(), origVecValue);
+ Value origVecValue = builder.create<vector::FromElementsOp>(
+ loc, oneVectorType, ValueRange{origValue});
+ origVecValue = builder.create<vector::BitCastOp>(loc, valueToStore.getType(),
+ origVecValue);
// Construct the final masked value and yield it.
Value maskedValue =
builder.create<arith::SelectOp>(loc, mask, valueToStore, origVecValue);
- maskedValue = builder.create<vector::BitCastOp>(loc, oneVectorType, maskedValue);
- auto scalarMaskedValue = builder.create<vector::ExtractOp>(loc, maskedValue, 0);
+ maskedValue =
+ builder.create<vector::BitCastOp>(loc, oneVectorType, maskedValue);
+ auto scalarMaskedValue =
+ builder.create<vector::ExtractOp>(loc, maskedValue, 0);
builder.create<memref::AtomicYieldOp>(loc, scalarMaskedValue);
}
@@ -359,17 +364,21 @@ static void rmwStore(OpBuilder &builder, Location loc,
MemRefValue linearizedMemref, Value linearizedIndex,
VectorValue valueToStore, Value mask) {
assert(valueToStore.getType().getRank() == 1 && "expected 1-D vector");
+
+  // Load the original value from memory as a one-element vector.
auto oneElemVecType =
VectorType::get({1}, linearizedMemref.getType().getElementType());
- auto origValue = builder.create<vector::LoadOp>(
+ Value origVecValue = builder.create<vector::LoadOp>(
loc, oneElemVecType, linearizedMemref, ValueRange{linearizedIndex});
- auto castedValue =
- builder.create<vector::BitCastOp>(loc, valueToStore.getType(), origValue);
- auto result =
- builder.create<arith::SelectOp>(loc, mask, castedValue, valueToStore);
- auto resultBitcast =
- builder.create<vector::BitCastOp>(loc, oneElemVecType, result);
- builder.create<vector::StoreOp>(loc, resultBitcast, linearizedMemref,
+ origVecValue = builder.create<vector::BitCastOp>(loc, valueToStore.getType(),
+ origVecValue);
+
+  // Construct the final masked value and store it back.
+ Value maskedValue =
+ builder.create<arith::SelectOp>(loc, mask, origVecValue, valueToStore);
+ maskedValue =
+ builder.create<vector::BitCastOp>(loc, oneElemVecType, maskedValue);
+ builder.create<vector::StoreOp>(loc, maskedValue, linearizedMemref,
linearizedIndex);
}
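For readers less familiar with the sub-byte packing, here is a standalone C++ illustration, independent of MLIR, of the byte-level effect both helpers implement for the documented example (memory |2|2|2|2|, mask |0|0|1|1|, new value |3|3|3|3|). Placing element 0 in the low bits is only for illustration; the real layout follows vector.bitcast semantics.

#include <cassert>
#include <cstdint>

// Read the 2-bit element at position `idx` from a packed byte.
static unsigned getI2(uint8_t byte, unsigned idx) {
  return (byte >> (2 * idx)) & 0x3u;
}

// Write the 2-bit element at position `idx` into a packed byte.
static uint8_t setI2(uint8_t byte, unsigned idx, unsigned value) {
  uint8_t cleared = byte & static_cast<uint8_t>(~(0x3u << (2 * idx)));
  return cleared | static_cast<uint8_t>((value & 0x3u) << (2 * idx));
}

int main() {
  uint8_t memory = 0;
  for (unsigned i = 0; i < 4; ++i)
    memory = setI2(memory, i, 2);                  // linearizedMemref = |2|2|2|2|
  const bool mask[4] = {false, false, true, true}; // mask             = |0|0|1|1|
  const unsigned value[4] = {3, 3, 3, 3};          // valueToStore     = |3|3|3|3|

  // Read the whole byte, update only the masked elements, write the byte back.
  // This is what the select + bitcast sequence achieves on vector<4xi2>.
  uint8_t updated = memory;
  for (unsigned i = 0; i < 4; ++i)
    if (mask[i])
      updated = setI2(updated, i, value[i]);

  // Result: |2|2|3|3|, matching the comment above atomicStore.
  assert(getI2(updated, 0) == 2 && getI2(updated, 1) == 2);
  assert(getI2(updated, 2) == 3 && getI2(updated, 3) == 3);
  return 0;
}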
>From 9fba9ebe54da51cc68c97f91df2fbe1fe5eabcfc Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Wed, 11 Dec 2024 00:01:17 +0800
Subject: [PATCH 11/12] adding missing parts of test
---
...late-narrow-type-unaligned-non-atomic.mlir | 28 ++++++++---
.../vector-emulate-narrow-type-unaligned.mlir | 50 +++++++++++++++++++
2 files changed, 72 insertions(+), 6 deletions(-)
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned-non-atomic.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned-non-atomic.mlir
index fa4d9cb5e4d4c7..512feda8f831cd 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned-non-atomic.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned-non-atomic.mlir
@@ -4,17 +4,21 @@
// memref.alloc exists here because sub-byte vector data types such as i2
// are currently not supported as input arguments.
-func.func @vector_store_i2_const_rmw(%arg0: vector<3xi2>) {
+func.func @vector_store_i2_const_index_two_rmw(%arg0: vector<3xi2>) {
%0 = memref.alloc() : memref<3x3xi2>
%c0 = arith.constant 0 : index
%c2 = arith.constant 2 : index
vector.store %arg0, %0[%c2, %c0] :memref<3x3xi2>, vector<3xi2>
return
}
-// CHECK: func @vector_store_i2_const_rmw(
+// The store hits bits [12:18), i.e. bytes [1:2] of the 3 bytes; both need an RMW.
+
+// CHECK: func @vector_store_i2_const_index_two_rmw(
// CHECK-SAME: %[[ARG0:.+]]: vector<3xi2>)
// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
// CHECK: %[[C1:.+]] = arith.constant 1 : index
+
+// Part 1 RMW sequence
// CHECK: %[[CST:.+]] = arith.constant dense<[false, false, true, true]>
// CHECK: %[[CST0:.+]] = arith.constant dense<0> : vector<4xi2>
// CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[ARG0]]
@@ -22,16 +26,28 @@ func.func @vector_store_i2_const_rmw(%arg0: vector<3xi2>) {
// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[EXTRACT]], %[[CST0]]
// CHECK-SAME: {offsets = [2], strides = [1]} : vector<2xi2> into vector<4xi2>
// CHECK: %[[LOAD:.+]] = vector.load
-
-// Actual part to do RMW sequence
// CHECK: %[[UPCAST:.+]] = vector.bitcast %[[LOAD]] : vector<1xi8> to vector<4xi2>
// CHECK: %[[SELECT:.+]] = arith.select %[[CST]], %[[UPCAST]], %[[INSERT]]
// CHECK: %[[DOWNCAST:.+]] = vector.bitcast %[[SELECT]]
// CHECK: vector.store %[[DOWNCAST]], %[[ALLOC]][%[[C1]]]
+// Part 2 RMW sequence
+// CHECK: %[[OFFSET:.+]] = arith.addi %[[C1]], %[[C1]] : index
+// CHECK: %[[EXTRACT2:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [2], sizes = [1], strides = [1]} : vector<3xi2> to vector<1xi2>
+// CHECK: %[[INSERT2:.+]] = vector.insert_strided_slice %[[EXTRACT2]], %[[CST0]]
+// CHECK-SAME: {offsets = [0], strides = [1]} : vector<1xi2> into vector<4xi2>
+// CHECK: %[[CST1:.+]] = arith.constant dense<[true, false, false, false]> : vector<4xi1>
+// CHECK: %[[LOAD2:.+]] = vector.load
+// CHECK: %[[UPCAST2:.+]] = vector.bitcast %[[LOAD2]] : vector<1xi8> to vector<4xi2>
+// CHECK: %[[SELECT2:.+]] = arith.select %[[CST1]], %[[UPCAST2]], %[[INSERT2]]
+// CHECK: %[[DOWNCAST2:.+]] = vector.bitcast %[[SELECT2]]
+// CHECK: vector.store %[[DOWNCAST2]], %[[ALLOC]][%[[OFFSET]]]
+
+
// -----
-func.func @vector_store_i2_atomic(%arg0: vector<7xi2>) {
+func.func @vector_store_i2_rmw(%arg0: vector<7xi2>) {
%0 = memref.alloc() : memref<3x7xi2>
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
@@ -39,7 +55,7 @@ func.func @vector_store_i2_atomic(%arg0: vector<7xi2>) {
return
}
-// CHECK: func @vector_store_i2_atomic(
+// CHECK: func @vector_store_i2_rmw(
// CHECK-SAME: %[[ARG0:.+]]:
// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<6xi8>
// CHECK: %[[C1:.+]] = arith.constant 1 : index
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
index e37053c25ff066..51281056c023e0 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
@@ -341,6 +341,56 @@ func.func @vector_maskedload_i4_constant_mask_unaligned(%passthru: vector<5xi2>)
/// vector.store
///----------------------------------------------------------------------------------------
+func.func @vector_store_i2_const_index_two_atomic(%arg0: vector<3xi2>) {
+ %0 = memref.alloc() : memref<3x3xi2>
+ %c0 = arith.constant 0 : index
+ %c2 = arith.constant 2 : index
+ vector.store %arg0, %0[%c2, %c0] :memref<3x3xi2>, vector<3xi2>
+ return
+}
+
+// In this example, 2 atomic RMWs are emitted.
+// The store hits bits [12:18), i.e. bytes [1:2] of the 3 bytes; both need an RMW.
+
+// CHECK-LABEL: func @vector_store_i2_const_index_two_atomic(
+// CHECK-SAME: %[[ARG0:.+]]: vector<3xi2>)
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[CST:.+]] = arith.constant dense<[false, false, true, true]> : vector<4xi1>
+// CHECK: %[[CST_0:.+]] = arith.constant dense<0> : vector<4xi2>
+
+// Part 1 atomic RMW sequence
+// CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [0], sizes = [2], strides = [1]} : vector<3xi2> to vector<2xi2>
+// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[EXTRACT]], %[[CST_0]]
+// CHECK-SAME: {offsets = [2], strides = [1]} : vector<2xi2> into vector<4xi2>
+// CHECK: %[[ATOMIC_RMW:.+]] = memref.generic_atomic_rmw %[[ALLOC]][%[[C1]]] : memref<3xi8> {
+// CHECK: %[[ARG:.+]]: i8):
+// CHECK: %[[FROM_ELEM:.+]] = vector.from_elements %[[ARG]] : vector<1xi8>
+// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[FROM_ELEM]] : vector<1xi8> to vector<4xi2>
+// CHECK: %[[SELECT:.+]] = arith.select %[[CST]], %[[INSERT]], %[[BITCAST]] : vector<4xi1>, vector<4xi2>
+// CHECK: %[[BITCAST2:.+]] = vector.bitcast %[[SELECT]] : vector<4xi2> to vector<1xi8>
+// CHECK: %[[EXTRACT2:.+]] = vector.extract %[[BITCAST2]][0] : i8 from vector<1xi8>
+// CHECK: memref.atomic_yield %[[EXTRACT2]] : i8
+
+// Part 2 atomic RMW sequence
+// CHECK: %[[ADDR2:.+]] = arith.addi %[[C1]], %[[C1]] : index
+// CHECK: %[[EXTRACT3:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [2], sizes = [1], strides = [1]} : vector<3xi2> to vector<1xi2>
+// CHECK: %[[INSERT2:.+]] = vector.insert_strided_slice %[[EXTRACT3]], %[[CST_0]]
+// CHECK-SAME: {offsets = [0], strides = [1]} : vector<1xi2> into vector<4xi2>
+// CHECK: %[[CST1:.+]] = arith.constant dense<[true, false, false, false]> : vector<4xi1>
+// CHECK: %[[ATOMIC_RMW2:.+]] = memref.generic_atomic_rmw %[[ALLOC]][%[[ADDR2]]] : memref<3xi8> {
+// CHECK: %[[ARG2:.+]]: i8):
+// CHECK: %[[FROM_ELEM2:.+]] = vector.from_elements %[[ARG2]] : vector<1xi8>
+// CHECK: %[[BITCAST4:.+]] = vector.bitcast %[[FROM_ELEM2]] : vector<1xi8> to vector<4xi2>
+// CHECK: %[[SELECT2:.+]] = arith.select %[[CST1]], %[[INSERT2]], %[[BITCAST4]] : vector<4xi1>, vector<4xi2>
+// CHECK: %[[BITCAST5:.+]] = vector.bitcast %[[SELECT2]] : vector<4xi2> to vector<1xi8>
+// CHECK: %[[EXTRACT4:.+]] = vector.extract %[[BITCAST5]][0] : i8 from vector<1xi8>
+// CHECK: memref.atomic_yield %[[EXTRACT4]] : i8
+
+// -----
+
func.func @vector_store_i2_atomic(%arg0: vector<7xi2>) {
%0 = memref.alloc() : memref<3x7xi2>
%c0 = arith.constant 0 : index
>From 69480baad9efd164e4a1197df39956c84f373316 Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Wed, 11 Dec 2024 21:05:19 +0800
Subject: [PATCH 12/12] add a TODO
---
mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp | 1 +
1 file changed, 1 insertion(+)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 95eab0998ee9df..917c59f811ea98 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -1809,6 +1809,7 @@ void vector::populateVectorNarrowTypeEmulationPatterns(
RewritePatternSet &patterns, bool useAtomicWrites) {
// Populate `vector.*` load conversion patterns.
+ // TODO: #119553 support atomicity
patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad,
ConvertVectorMaskedStore, ConvertVectorTransferRead>(
typeConverter, patterns.getContext());