[Mlir-commits] [mlir] [MLIR] Support non-atomic RMW option for emulated vector stores (PR #124887)
Alan Li
llvmlistbot at llvm.org
Tue Feb 4 05:21:07 PST 2025
https://github.com/lialan updated https://github.com/llvm/llvm-project/pull/124887
>From 3e2e4b513241ef47405a252532ba1352f12df04a Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Wed, 29 Jan 2025 05:14:11 +0000
Subject: [PATCH 1/6] First commit
---
.../Vector/Transforms/VectorRewritePatterns.h | 6 +-
.../Transforms/VectorEmulateNarrowType.cpp | 59 +++++++++++++++++--
.../Dialect/MemRef/TestEmulateNarrowType.cpp | 8 ++-
3 files changed, 66 insertions(+), 7 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index a59f06f3c1ef1b..43478aacb50a14 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -364,10 +364,12 @@ void populateVectorMaskMaterializationPatterns(RewritePatternSet &patterns,
PatternBenefit benefit = 1);
/// Appends patterns for emulating vector operations over narrow types with ops
-/// over wider types.
+/// over wider types. The `useAtomicWrites` indicates whether to use
+/// op `memref.generic_atomic_rmw` to perform atomic subbyte storing, or just a
+/// rmw sequence otherwise.
void populateVectorNarrowTypeEmulationPatterns(
const arith::NarrowTypeEmulationConverter &typeConverter,
- RewritePatternSet &patterns);
+ RewritePatternSet &patterns, bool useAtomicWrites = true);
/// Rewrite a vector `bitcast(trunci)` to use a more efficient sequence of
/// vector operations comprising `shuffle` and `bitwise` ops.
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 7ca88f1e0a0df9..8317317edb915f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -363,6 +363,29 @@ static void atomicStore(OpBuilder &builder, Location loc,
builder.create<memref::AtomicYieldOp>(loc, scalarMaskedValue);
}
+/// Generate a non-atomic read-modify-write sequence for subbyte storing.
+/// It has similar logic to `atomicStore`, but without the atomicity.
+static void rmwStore(OpBuilder &builder, Location loc,
+ MemRefValue linearizedMemref, Value linearizedIndex,
+ VectorValue valueToStore, Value mask) {
+ assert(valueToStore.getType().getRank() == 1 && "expected 1-D vector");
+
+ // Load the original value from memory, and cast it to the original element
+ // type.
+ auto oneElemVecType =
+ VectorType::get({1}, linearizedMemref.getType().getElementType());
+ Value origVecValue = builder.create<vector::LoadOp>(
+ loc, oneElemVecType, linearizedMemref, ValueRange{linearizedIndex});
+ origVecValue = builder.create<vector::BitCastOp>(loc, valueToStore.getType(),
+ origVecValue);
+
+ // Construct the final masked value and yield it.
+ Value maskedValue = selectAndCast(builder, loc, oneElemVecType, mask,
+ origVecValue, valueToStore);
+ builder.create<vector::StoreOp>(loc, maskedValue, linearizedMemref,
+ linearizedIndex);
+}
+
/// Extract `sliceNumElements` from source `vector` at `extractOffset`,
/// and insert it into an empty vector at `insertOffset`.
/// Inputs:
@@ -405,6 +428,10 @@ namespace {
struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
using OpConversionPattern::OpConversionPattern;
+ ConvertVectorStore(MLIRContext *context, bool useAtomicWrites)
+ : OpConversionPattern<vector::StoreOp>(context),
+ useAtomicWrites_(useAtomicWrites) {}
+
LogicalResult
matchAndRewrite(vector::StoreOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
@@ -611,13 +638,31 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
auto backMask = rewriter.create<arith::ConstantOp>(
loc, DenseElementsAttr::get(subWidthStoreMaskType, maskValues));
- atomicStore(rewriter, loc, memrefBase, currentDestIndex,
- cast<VectorValue>(subWidthStorePart), backMask.getResult());
+ subEmulatedWidthStore(rewriter, loc, memrefBase, currentDestIndex,
+ cast<VectorValue>(subWidthStorePart),
+ backMask.getResult());
}
rewriter.eraseOp(op);
return success();
}
+
+ /// Store a subbyte-sized value to memory, with a mask. Depending on the
+ /// configuration, it could be an atomic store or a non-atomic RMW sequence.
+ template <typename... Args>
+ void subEmulatedWidthStore(Args &&...args) const {
+ static_assert(
+ std::is_same_v<decltype(atomicStore), decltype(rmwStore)> &&
+ "`atomicStore` and `rmwStore` must have same signature, as per "
+ "the design to keep the code clean, which one to call is "
+ "determined by the `useAtomicWrites` flag.");
+ std::function<decltype(atomicStore)> storeFunc =
+ useAtomicWrites_ ? atomicStore : rmwStore;
+ storeFunc(std::forward<Args>(args)...);
+ }
+
+private:
+ const bool useAtomicWrites_;
};
//===----------------------------------------------------------------------===//
@@ -1930,12 +1975,18 @@ struct RewriteVectorTranspose : OpRewritePattern<vector::TransposeOp> {
void vector::populateVectorNarrowTypeEmulationPatterns(
const arith::NarrowTypeEmulationConverter &typeConverter,
- RewritePatternSet &patterns) {
+ RewritePatternSet &patterns, bool useAtomicWrites) {
// Populate `vector.*` conversion patterns.
- patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad, ConvertVectorStore,
+ // TODO: #119553 support atomicity
+ patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad,
ConvertVectorMaskedStore, ConvertVectorTransferRead>(
typeConverter, patterns.getContext());
+
+ // Populate `vector.*` store conversion patterns. The caller can choose
+ // to avoid emitting atomic operations and emit a plain load-modify-write
+ // sequence instead, if it is known there is no thread contention.
+ patterns.insert<ConvertVectorStore>(patterns.getContext(), useAtomicWrites);
}
void vector::populateVectorNarrowTypeRewritePatterns(
diff --git a/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp b/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp
index 7401e470ed4f2c..9a3fac623fbd7d 100644
--- a/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp
+++ b/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp
@@ -99,7 +99,8 @@ struct TestEmulateNarrowTypePass
arith::populateArithNarrowTypeEmulationPatterns(typeConverter, patterns);
memref::populateMemRefNarrowTypeEmulationPatterns(typeConverter, patterns);
- vector::populateVectorNarrowTypeEmulationPatterns(typeConverter, patterns);
+ vector::populateVectorNarrowTypeEmulationPatterns(typeConverter, patterns,
+ atomicStore);
if (failed(applyPartialConversion(op, target, std::move(patterns))))
signalPassFailure();
@@ -118,6 +119,11 @@ struct TestEmulateNarrowTypePass
*this, "skip-memref-type-conversion",
llvm::cl::desc("disable memref type conversion (to test failures)"),
llvm::cl::init(false)};
+
+ Option<bool> atomicStore{
+ *this, "atomic-store",
+ llvm::cl::desc("use atomic store instead of load-modify-write"),
+ llvm::cl::init(true)};
};
} // namespace
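
To make the intent of the new `rmwStore` helper concrete: for i2 elements packed four to a byte, a masked sub-byte store means "keep the destination bits where the mask is off, take the new bits where it is on". The standalone C++ sketch below illustrates that byte-level effect only; it is not code from the patch, and the actual pattern operates on vector<4xi2> values via vector.bitcast and arith.select rather than integer bit twiddling, with the element order inside the byte determined by the emulation's layout.

#include <cstdint>
#include <cstdio>

// Merge the i2 elements selected by `mask` from `src` into the byte `dst`,
// leaving the unselected elements of `dst` untouched. Each byte holds four
// i2 elements; `mask` has the pattern 0b11 for every selected element.
static uint8_t maskedSubByteStore(uint8_t dst, uint8_t src, uint8_t mask) {
  // Read-modify-write: keep dst where mask is 0, take src where mask is 1.
  return static_cast<uint8_t>((dst & ~mask) | (src & mask));
}

int main() {
  uint8_t byte = 0b10'01'11'00;      // four i2 elements: 2, 1, 3, 0
  uint8_t incoming = 0b00'00'00'10;  // overwrite only the last element with 2
  uint8_t mask = 0b00'00'00'11;      // select only the last i2 element
  byte = maskedSubByteStore(byte, incoming, mask);
  std::printf("result: 0x%02x\n", static_cast<unsigned>(byte));  // 0x9e -> 2, 1, 3, 2
  return 0;
}

The atomic flavor performs the same merge inside a memref.generic_atomic_rmw region, so concurrent writers to the same byte cannot lose each other's elements; the non-atomic flavor trades that guarantee for a cheaper load/select/store sequence.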
>From 66ecff4e0487f8520e1591db702e30dd8b732ca3 Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Wed, 29 Jan 2025 05:47:51 +0000
Subject: [PATCH 2/6] updates
---
.../Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp | 5 +++--
1 file changed, 3 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 8317317edb915f..82d8a6ffcc17cc 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -380,8 +380,9 @@ static void rmwStore(OpBuilder &builder, Location loc,
origVecValue);
// Construct the final masked value and yield it.
- Value maskedValue = selectAndCast(builder, loc, oneElemVecType, mask,
- origVecValue, valueToStore);
+ Value maskedValue =
+ downcastSelectAndUpcast(builder, loc, valueToStore.getType(),
+ oneElemVecType, mask, valueToStore, origVecValue);
builder.create<vector::StoreOp>(loc, maskedValue, linearizedMemref,
linearizedIndex);
}
>From f2d5e8ba19625a90d5f32cda3c9cd337a36c339d Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Wed, 29 Jan 2025 06:16:54 +0000
Subject: [PATCH 3/6] linting
---
.../Transforms/VectorEmulateNarrowType.cpp | 6 +-
...late-narrow-type-unaligned-non-atomic.mlir | 119 ++++++++++++++++++
2 files changed, 122 insertions(+), 3 deletions(-)
create mode 100644 mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned-non-atomic.mlir
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 82d8a6ffcc17cc..c848d3c0ca98aa 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -429,7 +429,7 @@ namespace {
struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
using OpConversionPattern::OpConversionPattern;
- ConvertVectorStore(MLIRContext *context, bool useAtomicWrites)
+ ConvertVectorStore(MLIRContext *context, bool useAtomicWrites)
: OpConversionPattern<vector::StoreOp>(context),
useAtomicWrites_(useAtomicWrites) {}
@@ -583,8 +583,8 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
extractSliceIntoByte(rewriter, loc, valueToStore, 0,
frontSubWidthStoreElem, *foldedNumFrontPadElems);
- atomicStore(rewriter, loc, memrefBase, currentDestIndex,
- cast<VectorValue>(value), frontMask.getResult());
+ subEmulatedWidthStore(rewriter, loc, memrefBase, currentDestIndex,
+ cast<VectorValue>(value), frontMask.getResult());
}
if (currentSourceIndex >= origElements) {
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned-non-atomic.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned-non-atomic.mlir
new file mode 100644
index 00000000000000..79f8869d043ee3
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned-non-atomic.mlir
@@ -0,0 +1,119 @@
+// RUN: mlir-opt --test-emulate-narrow-int="arith-compute-bitwidth=1 memref-load-bitwidth=8 atomic-store=false" --cse --split-input-file %s | FileCheck %s
+
+// TODO: remove memref.alloc() in the tests to reduce noise.
+// memref.alloc exists here because sub-byte vector data types such as i2
+// are currently not supported as input arguments.
+
+func.func @vector_store_i2_const_index_two_rmw(%arg0: vector<3xi2>) {
+ %0 = memref.alloc() : memref<3x3xi2>
+ %c0 = arith.constant 0 : index
+ %c2 = arith.constant 2 : index
+ vector.store %arg0, %0[%c2, %c0] :memref<3x3xi2>, vector<3xi2>
+ return
+}
+// Load from bit [12:18), byte [1:2] of total 3 bytes; both bytes need RMW.
+
+// CHECK: func @vector_store_i2_const_index_two_rmw(
+// CHECK-SAME: %[[ARG0:.+]]: vector<3xi2>)
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+
+// Part 1 RMW sequence
+// CHECK: %[[CST:.+]] = arith.constant dense<[false, false, true, true]>
+// CHECK: %[[CST0:.+]] = arith.constant dense<0> : vector<4xi2>
+// CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [0], sizes = [2], strides = [1]} : vector<3xi2> to vector<2xi2>
+// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[EXTRACT]], %[[CST0]]
+// CHECK-SAME: {offsets = [2], strides = [1]} : vector<2xi2> into vector<4xi2>
+// CHECK: %[[LOAD:.+]] = vector.load
+// CHECK: %[[UPCAST:.+]] = vector.bitcast %[[LOAD]] : vector<1xi8> to vector<4xi2>
+// CHECK: %[[SELECT:.+]] = arith.select %[[CST]], %[[INSERT]], %[[UPCAST]]
+// CHECK: %[[DOWNCAST:.+]] = vector.bitcast %[[SELECT]]
+// CHECK: vector.store %[[DOWNCAST]], %[[ALLOC]][%[[C1]]]
+
+// Part 2 RMW sequence
+// CHECK: %[[OFFSET:.+]] = arith.addi %[[C1]], %[[C1]] : index
+// CHECK: %[[EXTRACT2:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [2], sizes = [1], strides = [1]} : vector<3xi2> to vector<1xi2>
+// CHECK: %[[INSERT2:.+]] = vector.insert_strided_slice %[[EXTRACT2]], %[[CST0]]
+// CHECK-SAME: {offsets = [0], strides = [1]} : vector<1xi2> into vector<4xi2>
+// CHECK: %[[CST1:.+]] = arith.constant dense<[true, false, false, false]> : vector<4xi1>
+// CHECK: %[[LOAD2:.+]] = vector.load
+// CHECK: %[[UPCAST2:.+]] = vector.bitcast %[[LOAD2]] : vector<1xi8> to vector<4xi2>
+// CHECK: %[[SELECT2:.+]] = arith.select %[[CST1]], %[[INSERT2]], %[[UPCAST2]]
+// CHECK: %[[DOWNCAST2:.+]] = vector.bitcast %[[SELECT2]]
+// CHECK: vector.store %[[DOWNCAST2]], %[[ALLOC]][%[[OFFSET]]]
+
+
+// -----
+
+func.func @vector_store_i2_rmw(%arg0: vector<7xi2>) {
+ %0 = memref.alloc() : memref<3x7xi2>
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ vector.store %arg0, %0[%c1, %c0] :memref<3x7xi2>, vector<7xi2>
+ return
+}
+
+// CHECK: func @vector_store_i2_rmw(
+// CHECK-SAME: %[[ARG0:.+]]:
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<6xi8>
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[CST:.+]] = arith.constant dense<[false, false, false, true]>
+// CHECK: %[[CST0:.+]] = arith.constant dense<0> : vector<4xi2>
+// CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [0], sizes = [1], strides = [1]}
+// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[EXTRACT]], %[[CST0]]
+// CHECK-SAME: {offsets = [3], strides = [1]}
+// First sub-width RMW:
+// CHECK: %[[LOAD:.+]] = vector.load %[[ALLOC]][%[[C1]]]
+// CHECK: %[[UPCAST:.+]] = vector.bitcast %[[LOAD]] : vector<1xi8> to vector<4xi2>
+// CHECK: %[[SELECT:.+]] = arith.select %[[CST]], %[[INSERT]], %[[UPCAST]]
+// CHECK: %[[DOWNCAST:.+]] = vector.bitcast %[[SELECT]]
+// CHECK: vector.store %[[DOWNCAST]], %[[ALLOC]][%[[C1]]]
+
+// Full-width store:
+// CHECK: %[[INDEX:.+]] = arith.addi %[[C1]], %[[C1]]
+// CHECK: %[[EXTRACT1:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [1], sizes = [4], strides = [1]}
+// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[EXTRACT1]]
+// CHECK: vector.store %[[BITCAST]], %[[ALLOC]][%[[INDEX]]]
+
+// Second sub-width RMW:
+// CHECK: %[[INDEX2:.+]] = arith.addi %[[INDEX]], %[[C1]]
+// CHECK: %[[EXTRACT2:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [5], sizes = [2], strides = [1]}
+// CHECK: %[[INSERT2:.+]] = vector.insert_strided_slice %[[EXTRACT2]]
+// CHECK-SAME: {offsets = [0], strides = [1]}
+// CHECK: %[[CST1:.+]] = arith.constant dense<[true, true, false, false]>
+// CHECK: %[[LOAD1:.+]] = vector.load %[[ALLOC]][%[[INDEX2]]]
+// CHECK: %[[UPCAST1:.+]] = vector.bitcast %[[LOAD1]]
+// CHECK: %[[SELECT1:.+]] = arith.select %[[CST1]], %[[INSERT2]], %[[UPCAST1]]
+// CHECK: %[[DOWNCAST1:.+]] = vector.bitcast %[[SELECT1]]
+// CHECK: vector.store %[[DOWNCAST1]], %[[ALLOC]][%[[INDEX2]]]
+
+// -----
+
+func.func @vector_store_i2_single_rmw(%arg0: vector<1xi2>) {
+ %0 = memref.alloc() : memref<4x1xi2>
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ vector.store %arg0, %0[%c1, %c0] :memref<4x1xi2>, vector<1xi2>
+ return
+}
+
+// in this test, only emit 1 rmw store
+// CHECK: func @vector_store_i2_single_rmw(
+// CHECK-SAME: %[[ARG0:.+]]: vector<1xi2>)
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<1xi8>
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[CST:.+]] = arith.constant dense<[false, true, false, false]>
+// CHECK: %[[CST0:.+]] = arith.constant dense<0> : vector<4xi2>
+// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[ARG0]], %[[CST0]]
+// CHECK-SAME: {offsets = [1], strides = [1]} : vector<1xi2> into vector<4xi2>
+// CHECK: %[[LOAD:.+]] = vector.load %[[ALLOC]][%[[C0]]] : memref<1xi8>, vector<1xi8>
+// CHECK: %[[UPCAST:.+]] = vector.bitcast %[[LOAD]] : vector<1xi8> to vector<4xi2>
+// CHECK: %[[SELECT:.+]] = arith.select %[[CST]], %[[INSERT]], %[[UPCAST]]
+// CHECK: %[[DOWNCAST:.+]] = vector.bitcast %[[SELECT]]
+// CHECK: vector.store %[[DOWNCAST]], %[[ALLOC]][%[[C0]]]
+
>From c75f899316d9dd5bbd362569b145fcdd54fafeab Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Wed, 29 Jan 2025 09:20:33 +0000
Subject: [PATCH 4/6] update comments
---
.../Transforms/VectorEmulateNarrowType.cpp | 5 +----
...late-narrow-type-unaligned-non-atomic.mlir | 22 +++++++++++--------
2 files changed, 14 insertions(+), 13 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index c848d3c0ca98aa..00019d8c2d4bc0 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -364,14 +364,12 @@ static void atomicStore(OpBuilder &builder, Location loc,
}
/// Generate a non-atomic read-modify-write sequence for subbyte storing.
-/// It has similar logic to `atomicStore`, but without the atomicity.
+/// It has similar logic to `atomicStore`, but without atomicity.
static void rmwStore(OpBuilder &builder, Location loc,
MemRefValue linearizedMemref, Value linearizedIndex,
VectorValue valueToStore, Value mask) {
assert(valueToStore.getType().getRank() == 1 && "expected 1-D vector");
- // Load the original value from memory, and cast it to the original element
- // type.
auto oneElemVecType =
VectorType::get({1}, linearizedMemref.getType().getElementType());
Value origVecValue = builder.create<vector::LoadOp>(
@@ -379,7 +377,6 @@ static void rmwStore(OpBuilder &builder, Location loc,
origVecValue = builder.create<vector::BitCastOp>(loc, valueToStore.getType(),
origVecValue);
- // Construct the final masked value and yield it.
Value maskedValue =
downcastSelectAndUpcast(builder, loc, valueToStore.getType(),
oneElemVecType, mask, valueToStore, origVecValue);
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned-non-atomic.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned-non-atomic.mlir
index 79f8869d043ee3..84cae7d922b38a 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned-non-atomic.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned-non-atomic.mlir
@@ -4,16 +4,18 @@
// memref.alloc exists here because sub-byte vector data types such as i2
// are currently not supported as input arguments.
-func.func @vector_store_i2_const_index_two_rmw(%arg0: vector<3xi2>) {
+func.func @vector_store_i2_const_index_two_partial_stores(%arg0: vector<3xi2>) {
%0 = memref.alloc() : memref<3x3xi2>
%c0 = arith.constant 0 : index
%c2 = arith.constant 2 : index
vector.store %arg0, %0[%c2, %c0] :memref<3x3xi2>, vector<3xi2>
return
}
-// Load from bit [12:18), byte [1:2] of total 3 bytes; both bytes need RMW.
+// In this example, emit two RMW partial stores and no full-width store.
+// The store covers bits [12, 18), i.e. bytes 1 and 2 of the 3-byte buffer,
+// and both bytes are accessed only partially.
-// CHECK: func @vector_store_i2_const_index_two_rmw(
+// CHECK: func @vector_store_i2_const_index_two_partial_stores(
// CHECK-SAME: %[[ARG0:.+]]: vector<3xi2>)
// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
// CHECK: %[[C1:.+]] = arith.constant 1 : index
@@ -47,7 +49,7 @@ func.func @vector_store_i2_const_index_two_rmw(%arg0: vector<3xi2>) {
// -----
-func.func @vector_store_i2_rmw(%arg0: vector<7xi2>) {
+func.func @vector_store_i2_two_partial_one_full_stores(%arg0: vector<7xi2>) {
%0 = memref.alloc() : memref<3x7xi2>
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
@@ -55,7 +57,9 @@ func.func @vector_store_i2_rmw(%arg0: vector<7xi2>) {
return
}
-// CHECK: func @vector_store_i2_rmw(
+// In this example, emit two RMW stores and one full-width store.
+
+// CHECK: func @vector_store_i2_two_partial_one_full_stores(
// CHECK-SAME: %[[ARG0:.+]]:
// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<6xi8>
// CHECK: %[[C1:.+]] = arith.constant 1 : index
@@ -94,7 +98,7 @@ func.func @vector_store_i2_rmw(%arg0: vector<7xi2>) {
// -----
-func.func @vector_store_i2_single_rmw(%arg0: vector<1xi2>) {
+func.func @vector_store_i2_one_partial_store(%arg0: vector<1xi2>) {
%0 = memref.alloc() : memref<4x1xi2>
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
@@ -102,8 +106,9 @@ func.func @vector_store_i2_single_rmw(%arg0: vector<1xi2>) {
return
}
-// in this test, only emit 1 rmw store
-// CHECK: func @vector_store_i2_single_rmw(
+// In this test, emit only one partial RMW store, as the store fits within one byte.
+
+// CHECK: func @vector_store_i2_one_partial_store(
// CHECK-SAME: %[[ARG0:.+]]: vector<1xi2>)
// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<1xi8>
// CHECK: %[[C0:.+]] = arith.constant 0 : index
@@ -116,4 +121,3 @@ func.func @vector_store_i2_single_rmw(%arg0: vector<1xi2>) {
// CHECK: %[[SELECT:.+]] = arith.select %[[CST]], %[[INSERT]], %[[UPCAST]]
// CHECK: %[[DOWNCAST:.+]] = vector.bitcast %[[SELECT]]
// CHECK: vector.store %[[DOWNCAST]], %[[ALLOC]][%[[C0]]]
-
>From 562d87e3ce1fe0f2279ee1ca4e74d8f140670597 Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Mon, 3 Feb 2025 21:39:27 -0800
Subject: [PATCH 5/6] Rename
---
.../Transforms/VectorEmulateNarrowType.cpp | 35 +++++++------------
1 file changed, 12 insertions(+), 23 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 00019d8c2d4bc0..edc8881d6919ef 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -334,9 +334,9 @@ static Value downcastSelectAndUpcast(OpBuilder &builder, Location loc,
///
/// Result:
/// linearizedMemref = |2|2|3|3| : <4xi2> (<1xi8>)
-static void atomicStore(OpBuilder &builder, Location loc,
- MemRefValue linearizedMemref, Value storeIdx,
- VectorValue valueToStore, Value mask) {
+static void atomicRMWStore(OpBuilder &builder, Location loc,
+ MemRefValue linearizedMemref, Value storeIdx,
+ VectorValue valueToStore, Value mask) {
assert(valueToStore.getType().getRank() == 1 && "expected 1-D vector");
// Create an atomic load-modify-write region using
@@ -364,10 +364,11 @@ static void atomicStore(OpBuilder &builder, Location loc,
}
/// Generate a non-atomic read-modify-write sequence for subbyte storing.
-/// It has similar logic to `atomicStore`, but without atomicity.
-static void rmwStore(OpBuilder &builder, Location loc,
- MemRefValue linearizedMemref, Value linearizedIndex,
- VectorValue valueToStore, Value mask) {
+/// It has similar logic to `atomicRMWStore`, but without atomicity.
+static void nonAtomicRMWStore(OpBuilder &builder, Location loc,
+ MemRefValue linearizedMemref,
+ Value linearizedIndex, VectorValue valueToStore,
+ Value mask) {
assert(valueToStore.getType().getRank() == 1 && "expected 1-D vector");
auto oneElemVecType =
@@ -580,8 +581,10 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
extractSliceIntoByte(rewriter, loc, valueToStore, 0,
frontSubWidthStoreElem, *foldedNumFrontPadElems);
- subEmulatedWidthStore(rewriter, loc, memrefBase, currentDestIndex,
- cast<VectorValue>(value), frontMask.getResult());
+ auto storeFunc = useAtomicWrites_ ? atomicRMWStore : nonAtomicRMWStore;
+
+ storeFunc(rewriter, loc, memrefBase, currentDestIndex,
+ cast<VectorValue>(value), frontMask.getResult());
}
if (currentSourceIndex >= origElements) {
@@ -645,20 +648,6 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
return success();
}
- /// Store a subbyte-sized value to memory, with a mask. Depending on the
- /// configuration, it could be an atomic store or a non-atomic RMW sequence.
- template <typename... Args>
- void subEmulatedWidthStore(Args &&...args) const {
- static_assert(
- std::is_same_v<decltype(atomicStore), decltype(rmwStore)> &&
- "`atomicStore` and `rmwStore` must have same signature, as per "
- "the design to keep the code clean, which one to call is "
- "determined by the `useAtomicWrites` flag.");
- std::function<decltype(atomicStore)> storeFunc =
- useAtomicWrites_ ? atomicStore : rmwStore;
- storeFunc(std::forward<Args>(args)...);
- }
-
private:
const bool useAtomicWrites_;
};
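
A note on the simplification in this patch (illustrative, not MLIR-specific): because `atomicRMWStore` and `nonAtomicRMWStore` are free functions with identical signatures, the two branches of a conditional expression decay to the same function-pointer type, so a plain ternary replaces the earlier `std::function` plus `static_assert` forwarding helper. A minimal standalone C++ sketch of the idiom, with stand-in function names:

#include <cstdio>

// Two free functions with the same signature, standing in for the
// atomic and non-atomic store helpers.
static void storeAtomic(int idx, int value) {
  std::printf("atomic store of %d at index %d\n", value, idx);
}
static void storeNonAtomic(int idx, int value) {
  std::printf("non-atomic store of %d at index %d\n", value, idx);
}

int main() {
  bool useAtomicWrites = false;
  // Both branches have type void (*)(int, int), so the result of the
  // conditional is an ordinary function pointer; no std::function needed.
  auto storeFunc = useAtomicWrites ? storeAtomic : storeNonAtomic;
  storeFunc(/*idx=*/3, /*value=*/42);
  return 0;
}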
>From 4fbbcbe01f465b5d682013f7322c888a77ddc2bd Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Tue, 4 Feb 2025 05:15:27 -0800
Subject: [PATCH 6/6] Update name
---
.../Vector/Transforms/VectorRewritePatterns.h | 8 +++----
.../Transforms/VectorEmulateNarrowType.cpp | 22 +++++++++----------
.../Dialect/MemRef/TestEmulateNarrowType.cpp | 11 +++++-----
3 files changed, 21 insertions(+), 20 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 43478aacb50a14..7de4a6a3157506 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -364,12 +364,12 @@ void populateVectorMaskMaterializationPatterns(RewritePatternSet &patterns,
PatternBenefit benefit = 1);
/// Appends patterns for emulating vector operations over narrow types with ops
-/// over wider types. The `useAtomicWrites` indicates whether to use
-/// op `memref.generic_atomic_rmw` to perform atomic subbyte storing, or just a
-/// rmw sequence otherwise.
+/// over wider types. The `disableAtomicRMW` flag indicates whether to use a
+/// plain read-modify-write sequence instead of `memref.generic_atomic_rmw`
+/// when storing sub-byte values.
void populateVectorNarrowTypeEmulationPatterns(
const arith::NarrowTypeEmulationConverter &typeConverter,
- RewritePatternSet &patterns, bool useAtomicWrites = true);
+ RewritePatternSet &patterns, bool disableAtomicRMW = false);
/// Rewrite a vector `bitcast(trunci)` to use a more efficient sequence of
/// vector operations comprising `shuffle` and `bitwise` ops.
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index edc8881d6919ef..ba891de833ad2c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -334,9 +334,9 @@ static Value downcastSelectAndUpcast(OpBuilder &builder, Location loc,
///
/// Result:
/// linearizedMemref = |2|2|3|3| : <4xi2> (<1xi8>)
-static void atomicRMWStore(OpBuilder &builder, Location loc,
- MemRefValue linearizedMemref, Value storeIdx,
- VectorValue valueToStore, Value mask) {
+static void atomicRMW(OpBuilder &builder, Location loc,
+ MemRefValue linearizedMemref, Value storeIdx,
+ VectorValue valueToStore, Value mask) {
assert(valueToStore.getType().getRank() == 1 && "expected 1-D vector");
// Create an atomic load-modify-write region using
@@ -363,8 +363,8 @@ static void atomicRMWStore(OpBuilder &builder, Location loc,
builder.create<memref::AtomicYieldOp>(loc, scalarMaskedValue);
}
-/// Generate a non-atomic read-modify-write sequence for subbyte storing.
-/// It has similar logic to `atomicRMWStore`, but without atomicity.
+/// Generate a non-atomic read-modify-write sequence for storing to the
+/// emulated type. It has similar logic to `atomicRMW`, but without atomicity.
static void nonAtomicRMWStore(OpBuilder &builder, Location loc,
MemRefValue linearizedMemref,
Value linearizedIndex, VectorValue valueToStore,
@@ -427,9 +427,9 @@ namespace {
struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
using OpConversionPattern::OpConversionPattern;
- ConvertVectorStore(MLIRContext *context, bool useAtomicWrites)
+ ConvertVectorStore(MLIRContext *context, bool disableAtomicRMW)
: OpConversionPattern<vector::StoreOp>(context),
- useAtomicWrites_(useAtomicWrites) {}
+ disableAtomicRMW(disableAtomicRMW) {}
LogicalResult
matchAndRewrite(vector::StoreOp op, OpAdaptor adaptor,
@@ -581,7 +581,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
extractSliceIntoByte(rewriter, loc, valueToStore, 0,
frontSubWidthStoreElem, *foldedNumFrontPadElems);
- auto storeFunc = useAtomicWrites_ ? atomicRMWStore : nonAtomicRMWStore;
+ auto storeFunc = disableAtomicRMW ? nonAtomicRMW : atomicRMW;
storeFunc(rewriter, loc, memrefBase, currentDestIndex,
cast<VectorValue>(value), frontMask.getResult());
@@ -649,7 +649,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
}
private:
- const bool useAtomicWrites_;
+ const bool disableAtomicRMW;
};
//===----------------------------------------------------------------------===//
@@ -1962,7 +1962,7 @@ struct RewriteVectorTranspose : OpRewritePattern<vector::TransposeOp> {
void vector::populateVectorNarrowTypeEmulationPatterns(
const arith::NarrowTypeEmulationConverter &typeConverter,
- RewritePatternSet &patterns, bool useAtomicWrites) {
+ RewritePatternSet &patterns, bool disableAtomicRMW) {
// Populate `vector.*` conversion patterns.
// TODO: #119553 support atomicity
@@ -1973,7 +1973,7 @@ void vector::populateVectorNarrowTypeEmulationPatterns(
// Populate `vector.*` store conversion patterns. The caller can choose
// to avoid emitting atomic operations and emit a plain load-modify-write
// sequence instead, if it is known there is no thread contention.
- patterns.insert<ConvertVectorStore>(patterns.getContext(), useAtomicWrites);
+ patterns.insert<ConvertVectorStore>(patterns.getContext(), disableAtomicRMW);
}
void vector::populateVectorNarrowTypeRewritePatterns(
diff --git a/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp b/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp
index 9a3fac623fbd7d..ba2ea40e83d96e 100644
--- a/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp
+++ b/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp
@@ -100,7 +100,7 @@ struct TestEmulateNarrowTypePass
arith::populateArithNarrowTypeEmulationPatterns(typeConverter, patterns);
memref::populateMemRefNarrowTypeEmulationPatterns(typeConverter, patterns);
vector::populateVectorNarrowTypeEmulationPatterns(typeConverter, patterns,
- atomicStore);
+ disableAtomicRMW);
if (failed(applyPartialConversion(op, target, std::move(patterns))))
signalPassFailure();
@@ -120,10 +120,11 @@ struct TestEmulateNarrowTypePass
llvm::cl::desc("disable memref type conversion (to test failures)"),
llvm::cl::init(false)};
- Option<bool> atomicStore{
- *this, "atomic-store",
- llvm::cl::desc("use atomic store instead of load-modify-write"),
- llvm::cl::init(true)};
+ Option<bool> disableAtomicRMW{
+ *this, "disable-atomic-rmw",
+ llvm::cl::desc("disable atomic read-modify-write and prefer generating "
+ "normal sequence"),
+ llvm::cl::init(false)};
};
} // namespace
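
For reference (not part of the patch itself): after the final rename, the public knob is `disableAtomicRMW`, defaulting to false so existing users keep the atomic lowering. Below is a hedged sketch of how a downstream pipeline that already owns a type converter and pattern set could opt into the non-atomic sequence; the helper name is made up for illustration.

#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"

using namespace mlir;

// Illustrative helper only: register the narrow-type emulation patterns with
// the non-atomic store lowering, for pipelines that know their sub-byte
// stores never race.
static void addNonAtomicNarrowTypeStorePatterns(
    const arith::NarrowTypeEmulationConverter &typeConverter,
    RewritePatternSet &patterns) {
  // `disableAtomicRMW = true` emits the plain load/select/store sequence
  // checked in vector-emulate-narrow-type-unaligned-non-atomic.mlir; the
  // default (false) keeps memref.generic_atomic_rmw.
  vector::populateVectorNarrowTypeEmulationPatterns(typeConverter, patterns,
                                                    /*disableAtomicRMW=*/true);
}

The same behavior is exercised from the command line through the test pass option added above, e.g. --test-emulate-narrow-int="arith-compute-bitwidth=1 memref-load-bitwidth=8 disable-atomic-rmw=true".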