[Mlir-commits] [mlir] [MLIR] Implement emulation of static indexing subbyte type vector stores (PR #115922)
llvmlistbot at llvm.org
Thu Nov 14 09:11:32 PST 2024
https://github.com/lialan updated https://github.com/llvm/llvm-project/pull/115922
From 1430e8458aaad35256f55a43013e4bf99040c454 Mon Sep 17 00:00:00 2001
From: Ubuntu <450283+lialan at users.noreply.github.com>
Date: Fri, 25 Oct 2024 15:19:42 +0000
Subject: [PATCH 1/2] Implement vector stores
---
.../Transforms/VectorEmulateNarrowType.cpp | 212 +++++++++++++++---
.../vector-emulate-narrow-type-unaligned.mlir | 133 +++++++++++
2 files changed, 313 insertions(+), 32 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index e5f2a847994aee..ef9298fc09d739 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -33,6 +33,7 @@
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
#include <cstdint>
@@ -157,13 +158,10 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
/// Extracts 1-D subvector from a 1-D vector. It is a wrapper function for
/// emitting `vector.extract_strided_slice`.
static Value staticallyExtractSubvector(OpBuilder &rewriter, Location loc,
- VectorType extractType, Value source,
- int64_t frontOffset,
+ Value source, int64_t frontOffset,
int64_t subvecSize) {
auto vectorType = cast<VectorType>(source.getType());
- assert((vectorType.getRank() == 1 && extractType.getRank() == 1) &&
- "expected 1-D source and destination types");
- (void)vectorType;
+ assert(vectorType.getRank() == 1 && "expected 1-D source types");
assert(frontOffset + subvecSize <= vectorType.getNumElements() &&
"subvector out of bounds");
@@ -174,9 +172,12 @@ static Value staticallyExtractSubvector(OpBuilder &rewriter, Location loc,
auto offsets = rewriter.getI64ArrayAttr({frontOffset});
auto sizes = rewriter.getI64ArrayAttr({subvecSize});
auto strides = rewriter.getI64ArrayAttr({1});
+
+ auto resultVectorType =
+ VectorType::get({subvecSize}, vectorType.getElementType());
return rewriter
- .create<vector::ExtractStridedSliceOp>(loc, extractType, source, offsets,
- sizes, strides)
+ .create<vector::ExtractStridedSliceOp>(loc, resultVectorType, source,
+ offsets, sizes, strides)
->getResult(0);
}
@@ -185,12 +186,10 @@ static Value staticallyExtractSubvector(OpBuilder &rewriter, Location loc,
/// `vector.insert_strided_slice`.
static Value staticallyInsertSubvector(OpBuilder &rewriter, Location loc,
Value src, Value dest, int64_t offset) {
- auto srcType = cast<VectorType>(src.getType());
- auto destType = cast<VectorType>(dest.getType());
+ [[maybe_unused]] auto srcType = cast<VectorType>(src.getType());
+ [[maybe_unused]] auto destType = cast<VectorType>(dest.getType());
assert(srcType.getRank() == 1 && destType.getRank() == 1 &&
"expected source and dest to be vector type");
- (void)srcType;
- (void)destType;
auto offsets = rewriter.getI64ArrayAttr({offset});
auto strides = rewriter.getI64ArrayAttr({1});
return rewriter.create<vector::InsertStridedSliceOp>(loc, dest.getType(), src,
@@ -257,6 +256,63 @@ emulatedVectorLoad(OpBuilder &rewriter, Location loc, Value base,
newLoad);
}
+static void nonAtomicStore(ConversionPatternRewriter &rewriter, Location loc,
+ Value memref, Value index, Value value) {
+ auto originType = dyn_cast<VectorType>(value.getType());
+ auto memrefElemType = dyn_cast<MemRefType>(memref.getType()).getElementType();
+ auto scale = memrefElemType.getIntOrFloatBitWidth() /
+ originType.getElementType().getIntOrFloatBitWidth();
+ auto storeType =
+ VectorType::get({originType.getNumElements() / scale}, memrefElemType);
+ auto bitCast = rewriter.create<vector::BitCastOp>(loc, storeType, value);
+ rewriter.create<vector::StoreOp>(loc, bitCast.getResult(), memref, index);
+}
+
+/// atomically store a subbyte-sized value to memory, with a mask.
+static Value atomicStore(OpBuilder &rewriter, Location loc,
+ Value emulatedMemref, Value emulatedIndex,
+ TypedValue<VectorType> value, Value mask,
+ int64_t scale) {
+ auto atomicOp = rewriter.create<memref::GenericAtomicRMWOp>(
+ loc, emulatedMemref, ValueRange{emulatedIndex});
+ OpBuilder builder =
+ OpBuilder::atBlockEnd(atomicOp.getBody(), rewriter.getListener());
+ Value origValue = atomicOp.getCurrentValue();
+
+  // i8 -> vector<1xi8> -> vector<scale x sub-byte type>
+ auto oneVectorType = VectorType::get({1}, origValue.getType());
+ auto fromElem = builder.create<vector::FromElementsOp>(loc, oneVectorType,
+ ValueRange{origValue});
+ auto vectorBitCast =
+ builder.create<vector::BitCastOp>(loc, value.getType(), fromElem);
+
+ auto select =
+ builder.create<arith::SelectOp>(loc, mask, value, vectorBitCast);
+ auto bitcast2 = builder.create<vector::BitCastOp>(loc, oneVectorType, select);
+ auto extract = builder.create<vector::ExtractOp>(loc, bitcast2, 0);
+ builder.create<memref::AtomicYieldOp>(loc, extract.getResult());
+ return atomicOp;
+}
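
For illustration, with a scale of 4 (four i2 elements per i8), the atomic body
built above boils down to the following IR shape (a sketch with made-up SSA
names, mirroring the CHECK lines in the tests further down):

  %res = memref.generic_atomic_rmw %mem[%idx] : memref<3xi8> {
  ^bb0(%current: i8):
    %v = vector.from_elements %current : vector<1xi8>
    %unpacked = vector.bitcast %v : vector<1xi8> to vector<4xi2>
    %merged = arith.select %mask, %value, %unpacked : vector<4xi1>, vector<4xi2>
    %packed = vector.bitcast %merged : vector<4xi2> to vector<1xi8>
    %byte = vector.extract %packed[0] : i8 from vector<1xi8>
    memref.atomic_yield %byte : i8
  }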
+
+// Extract a slice of a vector, and insert it into a byte vector.
+static Value extractSliceIntoByte(ConversionPatternRewriter &rewriter,
+ Location loc, TypedValue<VectorType> vector,
+ int64_t sliceOffset, int64_t sliceNumElements,
+ int64_t byteOffset) {
+ auto vectorElementType = vector.getType().getElementType();
+ assert(8 % vectorElementType.getIntOrFloatBitWidth() == 0 &&
+ "vector element must be a valid sub-byte type");
+ auto scale = 8 / vectorElementType.getIntOrFloatBitWidth();
+ auto emptyByteVector = rewriter.create<arith::ConstantOp>(
+ loc, VectorType::get({scale}, vectorElementType),
+ rewriter.getZeroAttr(VectorType::get({scale}, vectorElementType)));
+ auto extracted = staticallyExtractSubvector(rewriter, loc, vector,
+ sliceOffset, sliceNumElements);
+ auto inserted = staticallyInsertSubvector(rewriter, loc, extracted,
+ emptyByteVector, byteOffset);
+ return inserted;
+}
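
As a concrete example, pulling 2 elements starting at sliceOffset = 2 out of a
vector<7xi2> and landing them at byteOffset = 1 of a fresh byte-sized vector
yields IR of this shape (sketch, illustrative names):

  %zero = arith.constant dense<0> : vector<4xi2>
  %slice = vector.extract_strided_slice %src
      {offsets = [2], sizes = [2], strides = [1]} : vector<7xi2> to vector<2xi2>
  %byte = vector.insert_strided_slice %slice, %zero
      {offsets = [1], strides = [1]} : vector<2xi2> into vector<4xi2>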
+
namespace {
//===----------------------------------------------------------------------===//
@@ -277,7 +333,8 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
auto loc = op.getLoc();
auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
- Type oldElementType = op.getValueToStore().getType().getElementType();
+ auto valueToStore = op.getValueToStore();
+ Type oldElementType = valueToStore.getType().getElementType();
Type newElementType = convertedType.getElementType();
int srcBits = oldElementType.getIntOrFloatBitWidth();
int dstBits = newElementType.getIntOrFloatBitWidth();
@@ -301,15 +358,15 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
// vector.store %bitcast, %alloc[%linear_index] : memref<16xi8>,
// vector<4xi8>
- auto origElements = op.getValueToStore().getType().getNumElements();
- if (origElements % scale != 0)
- return failure();
+ auto origElements = valueToStore.getType().getNumElements();
+ bool isUnalignedEmulation = origElements % scale != 0;
auto stridedMetadata =
rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
OpFoldResult linearizedIndices;
- std::tie(std::ignore, linearizedIndices) =
+ memref::LinearizedMemRefInfo linearizedInfo;
+ std::tie(linearizedInfo, linearizedIndices) =
memref::getLinearizedMemRefOffsetAndSize(
rewriter, loc, srcBits, dstBits,
stridedMetadata.getConstifiedMixedOffset(),
@@ -317,14 +374,108 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
stridedMetadata.getConstifiedMixedStrides(),
getAsOpFoldResult(adaptor.getIndices()));
- auto numElements = origElements / scale;
- auto bitCast = rewriter.create<vector::BitCastOp>(
- loc, VectorType::get(numElements, newElementType),
- op.getValueToStore());
+ auto foldedIntraVectorOffset =
+ isUnalignedEmulation
+ ? getConstantIntValue(linearizedInfo.intraDataOffset)
+ : 0;
+
+ if (!foldedIntraVectorOffset) {
+ // unimplemented case for dynamic front padding size
+ return failure();
+ }
+
+    // Conditions when atomic stores are not needed:
+    // 1. The source vector size is a multiple of the byte size.
+    // 2. The address of the store is byte-aligned.
+ if (!isUnalignedEmulation && *foldedIntraVectorOffset == 0) {
+ auto numElements = origElements / scale;
+ auto bitCast = rewriter.create<vector::BitCastOp>(
+ loc, VectorType::get(numElements, newElementType),
+ op.getValueToStore());
+ rewriter.replaceOpWithNewOp<vector::StoreOp>(
+ op, bitCast.getResult(), adaptor.getBase(),
+ getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
+ return llvm::success();
+ }
+
+ Value emulatedMemref = adaptor.getBase();
+ // the index into the target memref we are storing to
+ Value currentDestIndex =
+ getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices);
+ auto constantOne = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+ auto atomicMaskType = VectorType::get({scale}, rewriter.getI1Type());
+ // the index into the source vector we are currently processing
+ auto currentSourceIndex = 0;
+
+ // 1. atomic store for the first byte
+ auto frontAtomicStoreElem = (scale - *foldedIntraVectorOffset) % scale;
+ if (frontAtomicStoreElem != 0) {
+ auto frontMaskValues = llvm::SmallVector<bool>(scale, false);
+ if (*foldedIntraVectorOffset + origElements < scale) {
+ std::fill_n(frontMaskValues.begin() + *foldedIntraVectorOffset,
+ origElements, true);
+ frontAtomicStoreElem = origElements;
+ } else {
+        std::fill_n(frontMaskValues.end() - frontAtomicStoreElem,
+                    frontAtomicStoreElem, true);
+ }
+ auto frontMask = rewriter.create<arith::ConstantOp>(
+ loc, DenseElementsAttr::get(atomicMaskType, frontMaskValues));
+
+ currentSourceIndex = scale - (*foldedIntraVectorOffset);
+ auto value = extractSliceIntoByte(
+ rewriter, loc, cast<TypedValue<VectorType>>(valueToStore), 0,
+ frontAtomicStoreElem, *foldedIntraVectorOffset);
+
+ atomicStore(rewriter, loc, emulatedMemref, currentDestIndex,
+ cast<TypedValue<VectorType>>(value), frontMask.getResult(),
+ scale);
+
+ currentDestIndex = rewriter.create<arith::AddIOp>(
+ loc, rewriter.getIndexType(), currentDestIndex, constantOne);
+ }
+
+ if (currentSourceIndex >= origElements) {
+ rewriter.eraseOp(op);
+ return success();
+ }
+
+ // 2. non-atomic store
+ int64_t nonAtomicStoreSize = (origElements - currentSourceIndex) / scale;
+ int64_t numNonAtomicElements = nonAtomicStoreSize * scale;
+ if (nonAtomicStoreSize != 0) {
+ auto nonAtomicStorePart = staticallyExtractSubvector(
+ rewriter, loc, cast<TypedValue<VectorType>>(valueToStore),
+ currentSourceIndex, numNonAtomicElements);
+
+ nonAtomicStore(rewriter, loc, emulatedMemref, currentDestIndex,
+ nonAtomicStorePart);
+
+ currentSourceIndex += numNonAtomicElements;
+ currentDestIndex = rewriter.create<arith::AddIOp>(
+ loc, rewriter.getIndexType(), currentDestIndex,
+ rewriter.create<arith::ConstantIndexOp>(loc, nonAtomicStoreSize));
+ }
+
+ // 3. atomic store for the last byte
+ auto remainingElements = origElements - currentSourceIndex;
+ if (remainingElements != 0) {
+ auto atomicStorePart = extractSliceIntoByte(
+ rewriter, loc, cast<TypedValue<VectorType>>(valueToStore),
+ currentSourceIndex, remainingElements, 0);
+
+ // back mask
+ auto maskValues = llvm::SmallVector<bool>(scale, 0);
+ std::fill_n(maskValues.begin(), remainingElements, 1);
+ auto backMask = rewriter.create<arith::ConstantOp>(
+ loc, DenseElementsAttr::get(atomicMaskType, maskValues));
+
+ atomicStore(rewriter, loc, emulatedMemref, currentDestIndex,
+ cast<TypedValue<VectorType>>(atomicStorePart),
+ backMask.getResult(), scale);
+ }
- rewriter.replaceOpWithNewOp<vector::StoreOp>(
- op, bitCast.getResult(), adaptor.getBase(),
- getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
+ rewriter.eraseOp(op);
return success();
}
};
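
To make the three steps concrete: storing a vector<7xi2> at an intra-byte
offset of 3 elements (the vector_store_i8_2 test below) decomposes into an
atomic store of 1 element into the last i2 slot of the first byte, a regular
vector.store of the next 4 elements as one full i8, and a final atomic store
of the remaining 2 elements into the front of the last byte.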
@@ -532,9 +683,8 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
rewriter, loc, dyn_cast<TypedValue<VectorType>>(result), resultVector,
linearizedInfo.intraDataOffset, origElements);
} else if (isUnalignedEmulation) {
- result =
- staticallyExtractSubvector(rewriter, loc, op.getType(), result,
- *foldedIntraVectorOffset, origElements);
+ result = staticallyExtractSubvector(
+ rewriter, loc, result, *foldedIntraVectorOffset, origElements);
}
rewriter.replaceOp(op, result);
return success();
@@ -693,9 +843,8 @@ struct ConvertVectorMaskedLoad final
rewriter, loc, dyn_cast<TypedValue<VectorType>>(result),
op.getPassThru(), linearizedInfo.intraDataOffset, origElements);
} else if (isUnalignedEmulation) {
- result =
- staticallyExtractSubvector(rewriter, loc, op.getType(), result,
- *foldedIntraVectorOffset, origElements);
+ result = staticallyExtractSubvector(
+ rewriter, loc, result, *foldedIntraVectorOffset, origElements);
}
rewriter.replaceOp(op, result);
@@ -778,9 +927,8 @@ struct ConvertVectorTransferRead final
linearizedInfo.intraDataOffset,
origElements);
} else if (isUnalignedEmulation) {
- result =
- staticallyExtractSubvector(rewriter, loc, op.getType(), result,
- *foldedIntraVectorOffset, origElements);
+ result = staticallyExtractSubvector(
+ rewriter, loc, result, *foldedIntraVectorOffset, origElements);
}
rewriter.replaceOp(op, result);
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
index 7ed75ff7f1579c..0a007a7a26f559 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
@@ -249,3 +249,136 @@ func.func @vector_maskedload_i2_dynamic_indexing_mixed(%passthru: vector<3xi2>,
// CHECK: %[[IN8:.+]] = vector.insert %[[EX8]], %[[IN7]] [1] : i2 into vector<3xi2>
// CHECK: %[[EX9:.+]] = vector.extract %[[SELECT]][%[[INCIDX2]]] : i2 from vector<8xi2>
// CHECK: %[[IN9:.+]] = vector.insert %[[EX9]], %[[IN8]] [2] : i2 into vector<3xi2>
+
+// -----
+
+func.func @vector_store_i2_const(%arg0: vector<3xi2>) {
+ %0 = memref.alloc() : memref<3x3xi2>
+ %c0 = arith.constant 0 : index
+ %c2 = arith.constant 2 : index
+  vector.store %arg0, %0[%c2, %c0] : memref<3x3xi2>, vector<3xi2>
+ return
+}
+
+// In this example, emit 2 atomic stores: the first stores 2 elements and the
+// second stores 1 element.
+// CHECK: func @vector_store_i2_const(
+// CHECK-SAME: %[[ARG0:.+]]: vector<3xi2>)
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+
+// atomic store of the first byte
+// CHECK: %[[CST:.+]] = arith.constant dense<[false, false, true, true]> : vector<4xi1>
+// CHECK: %[[CST0:.+]] = arith.constant dense<0> : vector<4xi2>
+// CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [0], sizes = [2], strides = [1]} : vector<3xi2> to vector<2xi2>
+// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[EXTRACT]], %[[CST0]]
+// CHECK-SAME: {offsets = [2], strides = [1]} : vector<2xi2> into vector<4xi2>
+// CHECK: %[[ATOMIC_RMW:.+]] = memref.generic_atomic_rmw %[[ALLOC]][%[[C1]]] : memref<3xi8> {
+// CHECK: %[[ARG:.+]]: i8):
+// CHECK: %[[FROM_ELEM:.+]] = vector.from_elements %[[ARG]] : vector<1xi8>
+// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[FROM_ELEM]] : vector<1xi8> to vector<4xi2>
+// CHECK: %[[SELECT:.+]] = arith.select %[[CST]], %[[INSERT]], %[[BITCAST]] : vector<4xi1>, vector<4xi2>
+// CHECK: %[[BITCAST2:.+]] = vector.bitcast %[[SELECT]] : vector<4xi2> to vector<1xi8>
+// CHECK: %[[EXTRACT2:.+]] = vector.extract %[[BITCAST2]][0] : i8 from vector<1xi8>
+// CHECK: memref.atomic_yield %[[EXTRACT2]] : i8
+
+// atomic store of the second byte
+// CHECK: %[[ADDI:.+]] = arith.addi %[[C1]], %[[C1]] : index
+// CHECK: %[[EXTRACT2:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [2], sizes = [1], strides = [1]} : vector<3xi2> to vector<1xi2>
+// CHECK: %[[INSERT2:.+]] = vector.insert_strided_slice %[[EXTRACT2]], %[[CST0]]
+// CHECK-SAME: {offsets = [0], strides = [1]} : vector<1xi2> into vector<4xi2>
+// CHECK: %[[ATOMIC_RMW2:.+]] = memref.generic_atomic_rmw %[[ALLOC]][%[[ADDI]]] : memref<3xi8> {
+// CHECK: %[[ARG2:.+]]: i8):
+// CHECK: %[[FROM_ELEM2:.+]] = vector.from_elements %[[ARG2]] : vector<1xi8>
+// CHECK: %[[BITCAST3:.+]] = vector.bitcast %[[FROM_ELEM2]] : vector<1xi8> to vector<4xi2>
+// CHECK: %[[SELECT2:.+]] = arith.select %[[CST1:.+]], %[[INSERT2]], %[[BITCAST3]] : vector<4xi1>, vector<4xi2>
+// CHECK: %[[BITCAST4:.+]] = vector.bitcast %[[SELECT2]] : vector<4xi2> to vector<1xi8>
+// CHECK: %[[EXTRACT3:.+]] = vector.extract %[[BITCAST4]][0] : i8 from vector<1xi8>
+// CHECK: memref.atomic_yield %[[EXTRACT3]] : i8
+
+// -----
+
+func.func @vector_store_i8_2(%arg0: vector<7xi2>) {
+ %0 = memref.alloc() : memref<3x7xi2>
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+  vector.store %arg0, %0[%c1, %c0] : memref<3x7xi2>, vector<7xi2>
+ return
+}
+
+// In this example, emit 2 atomic stores and 1 non-atomic store.
+
+// CHECK: func @vector_store_i8_2(
+// CHECK-SAME: %[[ARG0:.+]]: vector<7xi2>)
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<6xi8>
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[CST:.+]] = arith.constant dense<[false, false, false, true]> : vector<4xi1>
+// CHECK: %[[CST0:.+]] = arith.constant dense<0> : vector<4xi2>
+
+// first atomic store
+// CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [0], sizes = [1], strides = [1]} : vector<7xi2> to vector<1xi2>
+// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[EXTRACT]], %[[CST0]]
+// CHECK-SAME: {offsets = [3], strides = [1]} : vector<1xi2> into vector<4xi2>
+// CHECK: %[[ATOMIC_RMW:.+]] = memref.generic_atomic_rmw %[[ALLOC]][%[[C1]]] : memref<6xi8> {
+// CHECK: %[[ARG:.+]]: i8):
+// CHECK: %[[FROM_ELEM:.+]] = vector.from_elements %[[ARG]] : vector<1xi8>
+// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[FROM_ELEM]] : vector<1xi8> to vector<4xi2>
+// CHECK: %[[SELECT:.+]] = arith.select %[[CST]], %[[INSERT]], %[[BITCAST]] : vector<4xi1>, vector<4xi2>
+// CHECK: %[[BITCAST2:.+]] = vector.bitcast %[[SELECT]] : vector<4xi2> to vector<1xi8>
+// CHECK: %[[EXTRACT2:.+]] = vector.extract %[[BITCAST2]][0] : i8 from vector<1xi8>
+// CHECK: memref.atomic_yield %[[EXTRACT2]] : i8
+
+// non-atomic store part
+// CHECK: %[[ADDR:.+]] = arith.addi %[[C1]], %[[C1]] : index
+// CHECK: %[[EXTRACT2:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [1], sizes = [4], strides = [1]} : vector<7xi2> to vector<4xi2>
+// CHECK: %[[BITCAST3:.+]] = vector.bitcast %[[EXTRACT2]] : vector<4xi2> to vector<1xi8>
+// CHECK: vector.store %[[BITCAST3]], %[[ALLOC]][%[[ADDR]]] : memref<6xi8>, vector<1xi8>
+
+// second atomic store
+// CHECK: %[[ADDR2:.+]] = arith.addi %[[ADDR]], %[[C1]] : index
+// CHECK: %[[EXTRACT3:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [5], sizes = [2], strides = [1]} : vector<7xi2> to vector<2xi2>
+// CHECK: %[[INSERT2:.+]] = vector.insert_strided_slice %[[EXTRACT3]], %[[CST0]]
+// CHECK-SAME: {offsets = [0], strides = [1]} : vector<2xi2> into vector<4xi2>
+// CHECK: %[[ATOMIC_RMW2:.+]] = memref.generic_atomic_rmw %[[ALLOC]][%[[ADDR2]]] : memref<6xi8> {
+// CHECK: %[[ARG2:.+]]: i8):
+// CHECK: %[[FROM_ELEM2:.+]] = vector.from_elements %[[ARG2]] : vector<1xi8>
+// CHECK: %[[BITCAST4:.+]] = vector.bitcast %[[FROM_ELEM2]] : vector<1xi8> to vector<4xi2>
+// CHECK: %[[SELECT2:.+]] = arith.select %[[CST1:.+]], %[[INSERT2]], %[[BITCAST4]] :
+// CHECK-SAME: vector<4xi1>, vector<4xi2>
+// CHECK: %[[BITCAST5:.+]] = vector.bitcast %[[SELECT2]] : vector<4xi2> to vector<1xi8>
+// CHECK: %[[EXTRACT4:.+]] = vector.extract %[[BITCAST5]][0] : i8 from vector<1xi8>
+// CHECK: memref.atomic_yield %[[EXTRACT4]] : i8
+
+// -----
+
+func.func @vector_store_i2_single_atomic(%arg0: vector<1xi2>) {
+ %0 = memref.alloc() : memref<4x1xi2>
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+  vector.store %arg0, %0[%c1, %c0] : memref<4x1xi2>, vector<1xi2>
+ return
+}
+
+// In this example, emit only 1 atomic store.
+// CHECK: func @vector_store_i2_single_atomic(
+// CHECK-SAME: %[[ARG0:.+]]: vector<1xi2>)
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<1xi8>
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[CST:.+]] = arith.constant dense<[false, true, false, false]> : vector<4xi1>
+// CHECK: %[[CST0:.+]] = arith.constant dense<0> : vector<4xi2>
+// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[ARG0]], %[[CST0]]
+// CHECK-SAME: {offsets = [1], strides = [1]} : vector<1xi2> into vector<4xi2>
+
+// CHECK: %[[ATOMIC_RMW:.+]] = memref.generic_atomic_rmw %[[ALLOC]][%[[C0]]] : memref<1xi8> {
+// CHECK: %[[ARG:.+]]: i8):
+// CHECK: %[[FROM_ELEM:.+]] = vector.from_elements %[[ARG]] : vector<1xi8>
+// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[FROM_ELEM]] : vector<1xi8> to vector<4xi2>
+// CHECK: %[[SELECT:.+]] = arith.select %[[CST]], %[[INSERT]], %[[BITCAST]] : vector<4xi1>, vector<4xi2>
+// CHECK: %[[BITCAST2:.+]] = vector.bitcast %[[SELECT]] : vector<4xi2> to vector<1xi8>
+// CHECK: %[[EXTRACT2:.+]] = vector.extract %[[BITCAST2]][0] : i8 from vector<1xi8>
+// CHECK: memref.atomic_yield %[[EXTRACT2]] : i8
From 9b81a3ffee2e3ead0372cbdf43fbb30371145997 Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Thu, 14 Nov 2024 12:10:15 -0500
Subject: [PATCH 2/2] Add support to avoid atomic operations.
---
.../Vector/Transforms/VectorRewritePatterns.h | 5 +-
.../Transforms/VectorEmulateNarrowType.cpp | 148 +++++++++++-------
2 files changed, 97 insertions(+), 56 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index a59f06f3c1ef1b..64bb3a2204cfdc 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -364,10 +364,11 @@ void populateVectorMaskMaterializationPatterns(RewritePatternSet &patterns,
PatternBenefit benefit = 1);
/// Appends patterns for emulating vector operations over narrow types with ops
-/// over wider types.
+/// over wider types. `useAtomicWrites` indicates whether to use atomic
+/// operations in the places where thread contention is possible.
void populateVectorNarrowTypeEmulationPatterns(
const arith::NarrowTypeEmulationConverter &typeConverter,
- RewritePatternSet &patterns);
+ RewritePatternSet &patterns, bool useAtomicWrites = true);
/// Rewrite a vector `bitcast(trunci)` to use a more efficient sequence of
/// vector operations comprising `shuffle` and `bitwise` ops.
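
A caller that knows the emulated buffer has no concurrent writers can opt out
of the atomic path via the new flag. A minimal sketch of such a pass setup
(the converter constructor and `ctx` are assumed here, following the usual
narrow-type emulation test-pass plumbing):

  // Emulate sub-byte element types over i8, and ask for plain
  // read-modify-write sequences instead of atomic RMW ops.
  arith::NarrowTypeEmulationConverter typeConverter(/*targetBitwidth=*/8);
  RewritePatternSet patterns(&ctx);
  vector::populateVectorNarrowTypeEmulationPatterns(typeConverter, patterns,
                                                    /*useAtomicWrites=*/false);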
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index ef9298fc09d739..08c379e85cebc8 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -76,7 +76,8 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
int numSrcElemsPerDest,
int numFrontPadElems = 0) {
- assert(numFrontPadElems < numSrcElemsPerDest && "intraDataOffset must be less than scale");
+ assert(numFrontPadElems < numSrcElemsPerDest &&
+ "intraDataOffset must be less than scale");
auto numElements = (numFrontPadElems + numSrcElems + numSrcElemsPerDest - 1) /
numSrcElemsPerDest;
@@ -256,23 +257,11 @@ emulatedVectorLoad(OpBuilder &rewriter, Location loc, Value base,
newLoad);
}
-static void nonAtomicStore(ConversionPatternRewriter &rewriter, Location loc,
- Value memref, Value index, Value value) {
- auto originType = dyn_cast<VectorType>(value.getType());
- auto memrefElemType = dyn_cast<MemRefType>(memref.getType()).getElementType();
- auto scale = memrefElemType.getIntOrFloatBitWidth() /
- originType.getElementType().getIntOrFloatBitWidth();
- auto storeType =
- VectorType::get({originType.getNumElements() / scale}, memrefElemType);
- auto bitCast = rewriter.create<vector::BitCastOp>(loc, storeType, value);
- rewriter.create<vector::StoreOp>(loc, bitCast.getResult(), memref, index);
-}
-
-/// atomically store a subbyte-sized value to memory, with a mask.
+/// Atomically store a subbyte-sized value to memory, with a mask.
static Value atomicStore(OpBuilder &rewriter, Location loc,
- Value emulatedMemref, Value emulatedIndex,
- TypedValue<VectorType> value, Value mask,
- int64_t scale) {
+ TypedValue<MemRefType> emulatedMemref,
+ Value emulatedIndex, TypedValue<VectorType> value,
+ Value mask, int64_t scale) {
auto atomicOp = rewriter.create<memref::GenericAtomicRMWOp>(
loc, emulatedMemref, ValueRange{emulatedIndex});
OpBuilder builder =
@@ -294,6 +283,27 @@ static Value atomicStore(OpBuilder &rewriter, Location loc,
return atomicOp;
}
+/// Generate a non-atomic read-modify-write sequence for subbyte storing.
+static Value rmwStore(OpBuilder &rewriter, Location loc,
+ TypedValue<MemRefType> emulatedMemref,
+ Value emulatedIndex, TypedValue<VectorType> value,
+ Value mask, int64_t numSrcElemsPerDest) {
+ auto emulatedIOType =
+ VectorType::get({1}, emulatedMemref.getType().getElementType());
+ auto elemLoad = rewriter.create<vector::LoadOp>(
+ loc, emulatedIOType, emulatedMemref, ValueRange{emulatedIndex});
+ auto fromBitcast = rewriter.create<vector::BitCastOp>(
+ loc,
+ VectorType::get({numSrcElemsPerDest}, value.getType().getElementType()),
+ elemLoad);
+ auto select = rewriter.create<arith::SelectOp>(loc, mask, value, fromBitcast);
+ auto toBitcast =
+ rewriter.create<vector::BitCastOp>(loc, emulatedIOType, select);
+  // vector.store has no results; return the stored value to keep the
+  // signature in sync with `atomicStore` for the dispatch below.
+  rewriter.create<vector::StoreOp>(loc, toBitcast, emulatedMemref,
+                                   emulatedIndex);
+  return toBitcast.getResult();
+}
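
For the same four-i2-per-byte example as above, the non-atomic variant lowers
to a plain load/select/store sequence (sketch, illustrative names); it is only
safe when no other thread writes the same byte concurrently:

  %old = vector.load %mem[%idx] : memref<3xi8>, vector<1xi8>
  %unpacked = vector.bitcast %old : vector<1xi8> to vector<4xi2>
  %merged = arith.select %mask, %value, %unpacked : vector<4xi1>, vector<4xi2>
  %packed = vector.bitcast %merged : vector<4xi2> to vector<1xi8>
  vector.store %packed, %mem[%idx] : memref<3xi8>, vector<1xi8>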
+
// Extract a slice of a vector, and insert it into a byte vector.
static Value extractSliceIntoByte(ConversionPatternRewriter &rewriter,
Location loc, TypedValue<VectorType> vector,
@@ -322,6 +332,10 @@ namespace {
struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
using OpConversionPattern::OpConversionPattern;
+ ConvertVectorStore(MLIRContext *context, bool useAtomicWrites)
+ : OpConversionPattern<vector::StoreOp>(context),
+ useAtomicWrites_(useAtomicWrites) {}
+
LogicalResult
matchAndRewrite(vector::StoreOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
@@ -343,7 +357,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
return rewriter.notifyMatchFailure(
op, "only dstBits % srcBits == 0 supported");
}
- int scale = dstBits / srcBits;
+ int numSrcElemsPerDest = dstBits / srcBits;
// Adjust the number of elements to store when emulating narrow types.
// Here only the 1-D vector store is considered, and the N-D memref types
@@ -359,7 +373,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
// vector<4xi8>
auto origElements = valueToStore.getType().getNumElements();
- bool isUnalignedEmulation = origElements % scale != 0;
+ bool isUnalignedEmulation = origElements % numSrcElemsPerDest != 0;
auto stridedMetadata =
rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
@@ -374,21 +388,21 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
stridedMetadata.getConstifiedMixedStrides(),
getAsOpFoldResult(adaptor.getIndices()));
- auto foldedIntraVectorOffset =
+ auto foldedNumFrontPadElems =
isUnalignedEmulation
? getConstantIntValue(linearizedInfo.intraDataOffset)
: 0;
- if (!foldedIntraVectorOffset) {
- // unimplemented case for dynamic front padding size
+ if (!foldedNumFrontPadElems) {
+      // Unimplemented case: dynamic front padding size.
return failure();
}
     // Conditions when atomic stores are not needed:
     // 1. The source vector size is a multiple of the byte size.
     // 2. The address of the store is byte-aligned.
- if (!isUnalignedEmulation && *foldedIntraVectorOffset == 0) {
- auto numElements = origElements / scale;
+ if (!isUnalignedEmulation && *foldedNumFrontPadElems == 0) {
+ auto numElements = origElements / numSrcElemsPerDest;
auto bitCast = rewriter.create<vector::BitCastOp>(
loc, VectorType::get(numElements, newElementType),
op.getValueToStore());
@@ -398,38 +412,41 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
return llvm::success();
}
- Value emulatedMemref = adaptor.getBase();
- // the index into the target memref we are storing to
+ TypedValue<MemRefType> emulatedMemref =
+ cast<TypedValue<MemRefType>>(adaptor.getBase());
+ // The index into the target memref we are storing to
Value currentDestIndex =
getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices);
auto constantOne = rewriter.create<arith::ConstantIndexOp>(loc, 1);
- auto atomicMaskType = VectorType::get({scale}, rewriter.getI1Type());
- // the index into the source vector we are currently processing
+ auto atomicMaskType =
+ VectorType::get({numSrcElemsPerDest}, rewriter.getI1Type());
+ // The index into the source vector we are currently processing
auto currentSourceIndex = 0;
- // 1. atomic store for the first byte
- auto frontAtomicStoreElem = (scale - *foldedIntraVectorOffset) % scale;
+ // 1. Atomic store for the first byte
+ auto frontAtomicStoreElem =
+ (numSrcElemsPerDest - *foldedNumFrontPadElems) % numSrcElemsPerDest;
if (frontAtomicStoreElem != 0) {
- auto frontMaskValues = llvm::SmallVector<bool>(scale, false);
- if (*foldedIntraVectorOffset + origElements < scale) {
- std::fill_n(frontMaskValues.begin() + *foldedIntraVectorOffset,
+ auto frontMaskValues = llvm::SmallVector<bool>(numSrcElemsPerDest, false);
+ if (*foldedNumFrontPadElems + origElements < numSrcElemsPerDest) {
+ std::fill_n(frontMaskValues.begin() + *foldedNumFrontPadElems,
origElements, true);
frontAtomicStoreElem = origElements;
} else {
         std::fill_n(frontMaskValues.end() - frontAtomicStoreElem,
                     frontAtomicStoreElem, true);
}
auto frontMask = rewriter.create<arith::ConstantOp>(
loc, DenseElementsAttr::get(atomicMaskType, frontMaskValues));
- currentSourceIndex = scale - (*foldedIntraVectorOffset);
+ currentSourceIndex = numSrcElemsPerDest - (*foldedNumFrontPadElems);
auto value = extractSliceIntoByte(
rewriter, loc, cast<TypedValue<VectorType>>(valueToStore), 0,
- frontAtomicStoreElem, *foldedIntraVectorOffset);
+ frontAtomicStoreElem, *foldedNumFrontPadElems);
- atomicStore(rewriter, loc, emulatedMemref, currentDestIndex,
- cast<TypedValue<VectorType>>(value), frontMask.getResult(),
- scale);
+ subByteStore(rewriter, loc, emulatedMemref, currentDestIndex,
+ cast<TypedValue<VectorType>>(value), frontMask.getResult(),
+ numSrcElemsPerDest);
currentDestIndex = rewriter.create<arith::AddIOp>(
loc, rewriter.getIndexType(), currentDestIndex, constantOne);
@@ -440,16 +457,24 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
return success();
}
- // 2. non-atomic store
- int64_t nonAtomicStoreSize = (origElements - currentSourceIndex) / scale;
- int64_t numNonAtomicElements = nonAtomicStoreSize * scale;
+ // 2. Non-atomic store
+ int64_t nonAtomicStoreSize =
+ (origElements - currentSourceIndex) / numSrcElemsPerDest;
+ int64_t numNonAtomicElements = nonAtomicStoreSize * numSrcElemsPerDest;
if (nonAtomicStoreSize != 0) {
auto nonAtomicStorePart = staticallyExtractSubvector(
rewriter, loc, cast<TypedValue<VectorType>>(valueToStore),
currentSourceIndex, numNonAtomicElements);
- nonAtomicStore(rewriter, loc, emulatedMemref, currentDestIndex,
- nonAtomicStorePart);
+ auto originType = dyn_cast<VectorType>(nonAtomicStorePart.getType());
+ auto memrefElemType =
+ dyn_cast<MemRefType>(emulatedMemref.getType()).getElementType();
+ auto storeType = VectorType::get(
+ {originType.getNumElements() / numSrcElemsPerDest}, memrefElemType);
+ auto bitCast = rewriter.create<vector::BitCastOp>(loc, storeType,
+ nonAtomicStorePart);
+ rewriter.create<vector::StoreOp>(loc, bitCast.getResult(), emulatedMemref,
+ currentDestIndex);
currentSourceIndex += numNonAtomicElements;
currentDestIndex = rewriter.create<arith::AddIOp>(
@@ -457,27 +482,37 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
rewriter.create<arith::ConstantIndexOp>(loc, nonAtomicStoreSize));
}
- // 3. atomic store for the last byte
+ // 3. Atomic store for the last byte
auto remainingElements = origElements - currentSourceIndex;
if (remainingElements != 0) {
auto atomicStorePart = extractSliceIntoByte(
rewriter, loc, cast<TypedValue<VectorType>>(valueToStore),
currentSourceIndex, remainingElements, 0);
- // back mask
- auto maskValues = llvm::SmallVector<bool>(scale, 0);
+ // Generate back mask
+ auto maskValues = llvm::SmallVector<bool>(numSrcElemsPerDest, 0);
std::fill_n(maskValues.begin(), remainingElements, 1);
auto backMask = rewriter.create<arith::ConstantOp>(
loc, DenseElementsAttr::get(atomicMaskType, maskValues));
- atomicStore(rewriter, loc, emulatedMemref, currentDestIndex,
- cast<TypedValue<VectorType>>(atomicStorePart),
- backMask.getResult(), scale);
+ subByteStore(rewriter, loc, emulatedMemref, currentDestIndex,
+ cast<TypedValue<VectorType>>(atomicStorePart),
+ backMask.getResult(), numSrcElemsPerDest);
}
rewriter.eraseOp(op);
return success();
}
+
+ template <typename... Args>
+ Value subByteStore(Args &&...args) const {
+ std::function<decltype(atomicStore)> storeFunc =
+ useAtomicWrites_ ? atomicStore : rmwStore;
+ return storeFunc(std::forward<Args>(args)...);
+ }
+
+private:
+ const bool useAtomicWrites_;
};
//===----------------------------------------------------------------------===//
@@ -1673,12 +1708,17 @@ struct RewriteVectorTranspose : OpRewritePattern<vector::TransposeOp> {
void vector::populateVectorNarrowTypeEmulationPatterns(
const arith::NarrowTypeEmulationConverter &typeConverter,
- RewritePatternSet &patterns) {
+ RewritePatternSet &patterns, bool useAtomicWrites) {
- // Populate `vector.*` conversion patterns.
- patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad, ConvertVectorStore,
+ // Populate `vector.*` load conversion patterns.
+ patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad,
ConvertVectorMaskedStore, ConvertVectorTransferRead>(
typeConverter, patterns.getContext());
+
+  // Populate `vector.*` store conversion patterns. The caller can choose to
+  // avoid emitting atomic operations and fall back to a read-modify-write
+  // sequence for stores when it is known there is no thread contention.
+ patterns.insert<ConvertVectorStore>(patterns.getContext(), useAtomicWrites);
}
void vector::populateVectorNarrowTypeRewritePatterns(