[Mlir-commits] [mlir] f0e1857 - [MLIR] Support non-atomic RMW option for emulated vector stores (#124887)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Feb 6 13:22:46 PST 2025


Author: Alan Li
Date: 2025-02-06T13:22:42-08:00
New Revision: f0e1857c84186263ab3bd8b9420b54c8c5810136

URL: https://github.com/llvm/llvm-project/commit/f0e1857c84186263ab3bd8b9420b54c8c5810136
DIFF: https://github.com/llvm/llvm-project/commit/f0e1857c84186263ab3bd8b9420b54c8c5810136.diff

LOG: [MLIR] Support non-atomic RMW option for emulated vector stores (#124887)

This patch is a follow-up to #115922. It adds an option to emit a
non-atomic read-modify-write (RMW) code sequence instead of an atomic RMW.
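
A minimal usage sketch, assuming the type converter, pattern set, and
MLIRContext (`ctx`) are already set up as in a typical narrow-type emulation
pass; only the new third argument to
`populateVectorNarrowTypeEmulationPatterns` comes from this patch:

  // Hedged sketch: request plain load/select/store sequences instead of
  // memref.generic_atomic_rmw when emulating sub-byte partial stores.
  arith::NarrowTypeEmulationConverter typeConverter(/*targetBitwidth=*/8);
  RewritePatternSet patterns(ctx); // `ctx` is an assumed MLIRContext *.
  vector::populateVectorNarrowTypeEmulationPatterns(
      typeConverter, patterns, /*disableAtomicRMW=*/true);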

Added: 
    mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned-non-atomic.mlir

Modified: 
    mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
    mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
    mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
    mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index a59f06f3c1ef1b3..7de4a6a3157506f 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -364,10 +364,12 @@ void populateVectorMaskMaterializationPatterns(RewritePatternSet &patterns,
                                                PatternBenefit benefit = 1);
 
 /// Appends patterns for emulating vector operations over narrow types with ops
-/// over wider types.
+/// over wider types. The `disableAtomicRMW` flag indicates whether to use a
+/// plain (non-atomic) read-modify-write sequence instead of
+/// `memref.generic_atomic_rmw` to perform sub-byte stores.
 void populateVectorNarrowTypeEmulationPatterns(
     const arith::NarrowTypeEmulationConverter &typeConverter,
-    RewritePatternSet &patterns);
+    RewritePatternSet &patterns, bool disableAtomicRMW = false);
 
 /// Rewrite a vector `bitcast(trunci)` to use a more efficient sequence of
 /// vector operations comprising `shuffle` and `bitwise` ops.

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 28ccbfbb6962e9c..bf1ecd7d4559caf 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -342,9 +342,9 @@ static Value downcastSelectAndUpcast(OpBuilder &builder, Location loc,
 ///
 /// Result:
 ///   linearizedMemref = |2|2|3|3| : <4xi2> (<1xi8>)
-static void atomicStore(OpBuilder &builder, Location loc,
-                        MemRefValue linearizedMemref, Value storeIdx,
-                        VectorValue valueToStore, Value mask) {
+static void atomicRMW(OpBuilder &builder, Location loc,
+                      MemRefValue linearizedMemref, Value storeIdx,
+                      VectorValue valueToStore, Value mask) {
   assert(valueToStore.getType().getRank() == 1 && "expected 1-D vector");
 
   // Create an atomic load-modify-write region using
@@ -371,6 +371,27 @@ static void atomicStore(OpBuilder &builder, Location loc,
   builder.create<memref::AtomicYieldOp>(loc, scalarMaskedValue);
 }
 
+/// Generate a non-atomic read-modify-write sequence for storing to the
+/// emulated type. It has the same logic as `atomicRMW`, but without atomicity.
+static void nonAtomicRMW(OpBuilder &builder, Location loc,
+                         MemRefValue linearizedMemref, Value linearizedIndex,
+                         VectorValue valueToStore, Value mask) {
+  assert(valueToStore.getType().getRank() == 1 && "expected 1-D vector");
+
+  auto oneElemVecType =
+      VectorType::get({1}, linearizedMemref.getType().getElementType());
+  Value origVecValue = builder.create<vector::LoadOp>(
+      loc, oneElemVecType, linearizedMemref, ValueRange{linearizedIndex});
+  origVecValue = builder.create<vector::BitCastOp>(loc, valueToStore.getType(),
+                                                   origVecValue);
+
+  Value maskedValue =
+      downcastSelectAndUpcast(builder, loc, valueToStore.getType(),
+                              oneElemVecType, mask, valueToStore, origVecValue);
+  builder.create<vector::StoreOp>(loc, maskedValue, linearizedMemref,
+                                  linearizedIndex);
+}
+
 /// Extract `sliceNumElements` from source `vector` at `extractOffset`,
 /// and insert it into an empty vector at `insertOffset`.
 /// Inputs:
@@ -415,6 +436,10 @@ namespace {
 struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
   using OpConversionPattern::OpConversionPattern;
 
+  ConvertVectorStore(MLIRContext *context, bool disableAtomicRMW)
+      : OpConversionPattern<vector::StoreOp>(context),
+        disableAtomicRMW(disableAtomicRMW) {}
+
   LogicalResult
   matchAndRewrite(vector::StoreOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
@@ -544,6 +569,8 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
     auto subWidthStoreMaskType =
         VectorType::get({numSrcElemsPerDest}, rewriter.getI1Type());
 
+    auto storeFunc = disableAtomicRMW ? nonAtomicRMW : atomicRMW;
+
     // 1. Partial width store for the leading byte.
     // When the store address is not aligned to emulated width boundary, deal
     // with the unaligned part so that the rest elements are aligned to width
@@ -568,8 +595,8 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
           extractSliceIntoByte(rewriter, loc, valueToStore, 0,
                                frontSubWidthStoreElem, *foldedNumFrontPadElems);
 
-      atomicStore(rewriter, loc, memrefBase, currentDestIndex,
-                  cast<VectorValue>(value), frontMask.getResult());
+      storeFunc(rewriter, loc, memrefBase, currentDestIndex,
+                cast<VectorValue>(value), frontMask.getResult());
     }
 
     if (currentSourceIndex >= origElements) {
@@ -624,13 +651,16 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
       auto backMask = rewriter.create<arith::ConstantOp>(
           loc, DenseElementsAttr::get(subWidthStoreMaskType, maskValues));
 
-      atomicStore(rewriter, loc, memrefBase, currentDestIndex,
-                  cast<VectorValue>(subWidthStorePart), backMask.getResult());
+      storeFunc(rewriter, loc, memrefBase, currentDestIndex,
+                cast<VectorValue>(subWidthStorePart), backMask.getResult());
     }
 
     rewriter.eraseOp(op);
     return success();
   }
+
+private:
+  const bool disableAtomicRMW;
 };
 
 //===----------------------------------------------------------------------===//
@@ -1969,12 +1999,18 @@ struct RewriteVectorTranspose : OpRewritePattern<vector::TransposeOp> {
 // The emulated type is inferred from the converted memref type.
 void vector::populateVectorNarrowTypeEmulationPatterns(
     const arith::NarrowTypeEmulationConverter &typeConverter,
-    RewritePatternSet &patterns) {
+    RewritePatternSet &patterns, bool disableAtomicRMW) {
 
   // Populate `vector.*` conversion patterns.
-  patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad, ConvertVectorStore,
+  // TODO: #119553 support atomicity
+  patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad,
                ConvertVectorMaskedStore, ConvertVectorTransferRead>(
       typeConverter, patterns.getContext());
+
+  // Populate `vector.*` store conversion patterns. The caller can choose
+  // to avoid emitting atomic operations and instead emit a plain
+  // read-modify-write sequence for stores when it is known there is no
+  // thread contention.
+  patterns.insert<ConvertVectorStore>(patterns.getContext(), disableAtomicRMW);
 }
 
 void vector::populateVectorNarrowTypeRewritePatterns(

diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned-non-atomic.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned-non-atomic.mlir
new file mode 100644
index 000000000000000..1d6263535ae8015
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned-non-atomic.mlir
@@ -0,0 +1,128 @@
+// RUN: mlir-opt --test-emulate-narrow-int="arith-compute-bitwidth=1 memref-load-bitwidth=8 disable-atomic-rmw=true" --cse --split-input-file %s | FileCheck %s
+
+// TODO: remove memref.alloc() in the tests to eliminate noise.
+// memref.alloc exists here because sub-byte vector data types such as i2
+// are currently not supported as input arguments.
+
+///----------------------------------------------------------------------------------------
+/// vector.store
+///----------------------------------------------------------------------------------------
+
+func.func @vector_store_i2_const_index_two_partial_stores(%arg0: vector<3xi2>) {
+    %0 = memref.alloc() : memref<3x3xi2>
+    %c0 = arith.constant 0 : index
+    %c2 = arith.constant 2 : index
+    vector.store %arg0, %0[%c2, %c0] :memref<3x3xi2>, vector<3xi2>
+    return
+}
+
+// Emit two non-atomic RMW partial stores. Store 6 bits from the input vector
+// into bits [12, 18) of a 3-byte output memref, i.e. into bytes 1 and 2.
+// Because both stores are partial, each byte is accessed partially through
+// masking.
+
+// CHECK: func @vector_store_i2_const_index_two_partial_stores(
+// CHECK-SAME: %[[ARG0:.+]]: vector<3xi2>)
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+
+// Part 1 RMW sequence
+// CHECK: %[[CST:.+]] = arith.constant dense<[false, false, true, true]>
+// CHECK: %[[CST0:.+]] = arith.constant dense<0> : vector<4xi2>
+// CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [0], sizes = [2], strides = [1]} : vector<3xi2> to vector<2xi2>
+// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[EXTRACT]], %[[CST0]]
+// CHECK-SAME: {offsets = [2], strides = [1]} : vector<2xi2> into vector<4xi2>
+// CHECK: %[[LOAD:.+]] = vector.load
+// CHECK: %[[DOWNCAST:.+]] = vector.bitcast %[[LOAD]] : vector<1xi8> to vector<4xi2>
+// CHECK: %[[SELECT:.+]] = arith.select %[[CST]], %[[INSERT]], %[[DOWNCAST]]
+// CHECK: %[[UPCAST:.+]] = vector.bitcast %[[SELECT]]
+// CHECK: vector.store %[[UPCAST]], %[[ALLOC]][%[[C1]]]
+
+// Part 2 RMW sequence
+// CHECK: %[[OFFSET:.+]] = arith.addi %[[C1]], %[[C1]] : index
+// CHECK: %[[EXTRACT2:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [2], sizes = [1], strides = [1]} : vector<3xi2> to vector<1xi2>
+// CHECK: %[[INSERT2:.+]] = vector.insert_strided_slice %[[EXTRACT2]], %[[CST0]]
+// CHECK-SAME: {offsets = [0], strides = [1]} : vector<1xi2> into vector<4xi2>
+// CHECK: %[[CST1:.+]] = arith.constant dense<[true, false, false, false]> : vector<4xi1>
+// CHECK: %[[LOAD2:.+]] = vector.load
+// CHECK: %[[UPCAST2:.+]] = vector.bitcast %[[LOAD2]] : vector<1xi8> to vector<4xi2>
+// CHECK: %[[SELECT2:.+]] = arith.select %[[CST1]], %[[INSERT2]], %[[UPCAST2]]
+// CHECK: %[[DOWNCAST2:.+]] = vector.bitcast %[[SELECT2]]
+// CHECK: vector.store %[[DOWNCAST2]], %[[ALLOC]][%[[OFFSET]]]
+
+
+// -----
+
+func.func @vector_store_i2_two_partial_one_full_stores(%arg0: vector<7xi2>) {
+    %0 = memref.alloc() : memref<3x7xi2>
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    vector.store %arg0, %0[%c1, %c0] :memref<3x7xi2>, vector<7xi2>
+    return
+}
+
+// In this example, emit two RMW stores and one full-width store.
+
+// CHECK: func @vector_store_i2_two_partial_one_full_stores(
+// CHECK-SAME: %[[ARG0:.+]]:
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<6xi8>
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[CST:.+]] = arith.constant dense<[false, false, false, true]>
+// CHECK: %[[CST0:.+]] = arith.constant dense<0> : vector<4xi2>
+// CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [0], sizes = [1], strides = [1]}
+// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[EXTRACT]], %[[CST0]]
+// CHECK-SAME: {offsets = [3], strides = [1]}
+// First sub-width RMW:
+// CHECK: %[[LOAD:.+]] = vector.load %[[ALLOC]][%[[C1]]]
+// CHECK: %[[UPCAST:.+]] = vector.bitcast %[[LOAD]] : vector<1xi8> to vector<4xi2>
+// CHECK: %[[SELECT:.+]] = arith.select %[[CST]], %[[INSERT]], %[[UPCAST]]
+// CHECK: %[[DOWNCAST:.+]] = vector.bitcast %[[SELECT]]
+// CHECK: vector.store %[[DOWNCAST]], %[[ALLOC]][%[[C1]]]
+
+// Full-width store:
+// CHECK: %[[INDEX:.+]] = arith.addi %[[C1]], %[[C1]]
+// CHECK: %[[EXTRACT1:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [1], sizes = [4], strides = [1]}
+// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[EXTRACT1]]
+// CHECK: vector.store %[[BITCAST]], %[[ALLOC]][%[[INDEX]]]
+
+// Second sub-width RMW:
+// CHECK: %[[INDEX2:.+]] = arith.addi %[[INDEX]], %[[C1]]
+// CHECK: %[[EXTRACT2:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [5], sizes = [2], strides = [1]}
+// CHECK: %[[INSERT2:.+]] = vector.insert_strided_slice %[[EXTRACT2]]
+// CHECK-SAME: {offsets = [0], strides = [1]}
+// CHECK: %[[CST1:.+]] = arith.constant dense<[true, true, false, false]>
+// CHECK: %[[LOAD2:.+]] = vector.load %[[ALLOC]][%[[INDEX2]]]
+// CHECK: %[[UPCAST2:.+]] = vector.bitcast %[[LOAD2]]
+// CHECK: %[[SELECT2:.+]] = arith.select %[[CST1]], %[[INSERT2]], %[[UPCAST2]]
+// CHECK: %[[DOWNCAST2:.+]] = vector.bitcast %[[SELECT2]]
+// CHECK: vector.store %[[DOWNCAST2]], %[[ALLOC]][%[[INDEX2]]]
+
+// -----
+
+func.func @vector_store_i2_const_index_one_partial_store(%arg0: vector<1xi2>) {
+    %0 = memref.alloc() : memref<4x1xi2>
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    vector.store %arg0, %0[%c1, %c0] :memref<4x1xi2>, vector<1xi2>
+    return
+}
+
+// In this test, only one partial RMW store is emitted, as the store fits
+// within a single byte.
+
+// CHECK: func @vector_store_i2_const_index_one_partial_store(
+// CHECK-SAME: %[[ARG0:.+]]: vector<1xi2>)
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<1xi8>
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[CST:.+]] = arith.constant dense<[false, true, false, false]>
+// CHECK: %[[CST0:.+]] = arith.constant dense<0> : vector<4xi2>
+// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[ARG0]], %[[CST0]]
+// CHECK-SAME: {offsets = [1], strides = [1]} : vector<1xi2> into vector<4xi2>
+// CHECK: %[[LOAD:.+]] = vector.load %[[ALLOC]][%[[C0]]] : memref<1xi8>, vector<1xi8>
+// CHECK: %[[UPCAST:.+]] = vector.bitcast %[[LOAD]] : vector<1xi8> to vector<4xi2>
+// CHECK: %[[SELECT:.+]] = arith.select %[[CST]], %[[INSERT]], %[[UPCAST]]
+// CHECK: %[[DOWNCAST:.+]] = vector.bitcast %[[SELECT]]
+// CHECK: vector.store %[[DOWNCAST]], %[[ALLOC]][%[[C0]]]

diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
index 89cb8e0bde87556..6fc974200c6f399 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
@@ -369,10 +369,9 @@ func.func @vector_store_i2_const_index_two_partial_stores(%arg0: vector<3xi2>) {
     return
 }
 
-// In this example, emit 2 atomic RMWs.
-//
-// Note, sizeof(%src) = 18 bits. This is modelled as %src_as_bytes:
-// <3xi8> (bits [0, 18) with the input values from %src, and [18, 24) are masked out)
+// Emit two atomic RMW partial stores. Store 6 bits from the input vector
+// into bits [12, 18) of a 3-byte output memref, i.e. into bytes 1 and 2.
+// Because both stores are partial, each byte is accessed partially through
+// masking.
 
 // CHECK-LABEL: func @vector_store_i2_const_index_two_partial_stores(
 // CHECK-SAME: %[[ARG0:.+]]: vector<3xi2>)

diff --git a/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp b/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp
index 7401e470ed4f2c3..ba2ea40e83d96e4 100644
--- a/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp
+++ b/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp
@@ -99,7 +99,8 @@ struct TestEmulateNarrowTypePass
 
     arith::populateArithNarrowTypeEmulationPatterns(typeConverter, patterns);
     memref::populateMemRefNarrowTypeEmulationPatterns(typeConverter, patterns);
-    vector::populateVectorNarrowTypeEmulationPatterns(typeConverter, patterns);
+    vector::populateVectorNarrowTypeEmulationPatterns(typeConverter, patterns,
+                                                      disableAtomicRMW);
 
     if (failed(applyPartialConversion(op, target, std::move(patterns))))
       signalPassFailure();
@@ -118,6 +119,12 @@ struct TestEmulateNarrowTypePass
       *this, "skip-memref-type-conversion",
       llvm::cl::desc("disable memref type conversion (to test failures)"),
       llvm::cl::init(false)};
+
+  Option<bool> disableAtomicRMW{
+      *this, "disable-atomic-rmw",
+      llvm::cl::desc("disable atomic read-modify-write and prefer generating "
+                     "a plain (non-atomic) sequence"),
+      llvm::cl::init(false)};
 };
 } // namespace
 


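For downstream users, a minimal sketch of wiring the new knob into a
standalone pass, modeled on the TestEmulateNarrowTypePass changes above. The
pass name, the 8-bit load width, and the simplified legality setup are
assumptions for illustration, not part of this commit:

  #include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h"
  #include "mlir/Dialect/Func/IR/FuncOps.h"
  #include "mlir/Dialect/MemRef/Transforms/Transforms.h"
  #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
  #include "mlir/Pass/Pass.h"
  #include "mlir/Transforms/DialectConversion.h"

  using namespace mlir;

  namespace {
  // Hypothetical pass; only the third argument of
  // populateVectorNarrowTypeEmulationPatterns is introduced by this patch.
  struct EmulateNarrowStoresPass
      : PassWrapper<EmulateNarrowStoresPass, OperationPass<func::FuncOp>> {
    MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(EmulateNarrowStoresPass)

    Option<bool> disableAtomicRMW{
        *this, "disable-atomic-rmw",
        llvm::cl::desc("emit plain RMW sequences instead of atomic ones"),
        llvm::cl::init(false)};

    StringRef getArgument() const final { return "emulate-narrow-stores"; }

    void runOnOperation() override {
      MLIRContext *ctx = &getContext();
      // Assumed 8-bit load width, matching memref-load-bitwidth=8 in the tests.
      arith::NarrowTypeEmulationConverter typeConverter(/*targetBitwidth=*/8);
      memref::populateMemRefNarrowTypeEmulationConversions(typeConverter);

      ConversionTarget target(*ctx);
      target.markUnknownOpDynamicallyLegal(
          [&](Operation *op) { return typeConverter.isLegal(op); });

      RewritePatternSet patterns(ctx);
      vector::populateVectorNarrowTypeEmulationPatterns(typeConverter, patterns,
                                                        disableAtomicRMW);
      if (failed(applyPartialConversion(getOperation(), target,
                                        std::move(patterns))))
        signalPassFailure();
    }
  };
  } // namespace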
        


More information about the Mlir-commits mailing list