[Mlir-commits] [mlir] [MLIR] Implement emulation of static indexing subbyte type vector stores (PR #115922)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Jan 12 17:35:17 PST 2025


https://github.com/lialan updated https://github.com/llvm/llvm-project/pull/115922

>From 331b3846b869afcbee2b4e74ca6b599d06043b8b Mon Sep 17 00:00:00 2001
From: Ubuntu <450283+lialan at users.noreply.github.com>
Date: Fri, 10 Jan 2025 10:17:26 +0000
Subject: [PATCH 1/2] [MLIR] atomic emulation of static indexing subbyte type
 vector stores

This patch enables unaligned, statically indexed stores of vectors whose element type is narrower than the emulation width.

To illustrate the mechanism, consider the example of storing vector<7xi2> into memref<3x7xi2>[1, 0].
In this case the linearized bit indices being overwritten are [14, 28), which cover:

* the last 2 bits of byte no.2
* byte no.3
* first 4 bits of byte no.4

Because memory is accessed in whole bytes, byte no.2 and byte no.4 in the above example are only partially modified.
In a multi-threaded scenario, these two bytes must be updated atomically in order to avoid data races.
---
 .../Transforms/VectorEmulateNarrowType.cpp    | 286 +++++++++++++++---
 .../vector-emulate-narrow-type-unaligned.mlir | 137 +++++++++
 2 files changed, 378 insertions(+), 45 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index d04f302200519e..38f6ce78b76eae 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -45,6 +45,9 @@ using namespace mlir;
 #define DBGSNL() (llvm::dbgs() << "\n")
 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
 
+using VectorValue = TypedValue<VectorType>;
+using MemRefValue = TypedValue<MemRefType>;
+
 /// Returns a compressed mask for the emulated vector. For example, when
 /// emulating an eight-element `i8` vector with `i32` (i.e. when the source
 /// elements span two dest elements), this method compresses `vector<8xi1>`
@@ -194,13 +197,10 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
 /// Extracts 1-D subvector from a 1-D vector. It is a wrapper function for
 /// emitting `vector.extract_strided_slice`.
 static Value staticallyExtractSubvector(OpBuilder &rewriter, Location loc,
-                                        VectorType extractType, Value source,
-                                        int64_t frontOffset,
+                                        Value source, int64_t frontOffset,
                                         int64_t subvecSize) {
   auto vectorType = cast<VectorType>(source.getType());
-  assert((vectorType.getRank() == 1 && extractType.getRank() == 1) &&
-         "expected 1-D source and destination types");
-  (void)vectorType;
+  assert(vectorType.getRank() == 1 && "expected 1-D source types");
   assert(frontOffset + subvecSize <= vectorType.getNumElements() &&
          "subvector out of bounds");
 
@@ -211,9 +211,12 @@ static Value staticallyExtractSubvector(OpBuilder &rewriter, Location loc,
   auto offsets = rewriter.getI64ArrayAttr({frontOffset});
   auto sizes = rewriter.getI64ArrayAttr({subvecSize});
   auto strides = rewriter.getI64ArrayAttr({1});
+
+  auto resultVectorType =
+      VectorType::get({subvecSize}, vectorType.getElementType());
   return rewriter
-      .create<vector::ExtractStridedSliceOp>(loc, extractType, source, offsets,
-                                             sizes, strides)
+      .create<vector::ExtractStridedSliceOp>(loc, resultVectorType, source,
+                                             offsets, sizes, strides)
       ->getResult(0);
 }
 
@@ -237,8 +240,8 @@ static Value staticallyInsertSubvector(OpBuilder &rewriter, Location loc,
 /// function emits multiple `vector.extract` and `vector.insert` ops, so only
 /// use it when `offset` cannot be folded into a constant value.
 static Value dynamicallyExtractSubVector(OpBuilder &rewriter, Location loc,
-                                         TypedValue<VectorType> source,
-                                         Value dest, OpFoldResult offset,
+                                         VectorValue source, Value dest,
+                                         OpFoldResult offset,
                                          int64_t numElementsToExtract) {
   for (int i = 0; i < numElementsToExtract; ++i) {
     Value extractLoc =
@@ -255,8 +258,8 @@ static Value dynamicallyExtractSubVector(OpBuilder &rewriter, Location loc,
 
 /// Inserts a 1-D subvector into a 1-D `dest` vector at index `destOffsetVar`.
 static Value dynamicallyInsertSubVector(RewriterBase &rewriter, Location loc,
-                                        TypedValue<VectorType> source,
-                                        Value dest, OpFoldResult destOffsetVar,
+                                        VectorValue source, Value dest,
+                                        OpFoldResult destOffsetVar,
                                         size_t length) {
   assert(length > 0 && "length must be greater than 0");
   Value destOffsetVal =
@@ -277,11 +280,12 @@ static Value dynamicallyInsertSubVector(RewriterBase &rewriter, Location loc,
 /// specifically, use `emulatedElemType` for loading a vector of `origElemType`.
 /// The load location is given by `base` and `linearizedIndices`, and the
 /// load size is given by `numEmulatedElementsToLoad`.
-static TypedValue<VectorType>
-emulatedVectorLoad(OpBuilder &rewriter, Location loc, Value base,
-                   OpFoldResult linearizedIndices,
-                   int64_t numEmultedElementsToLoad, Type origElemType,
-                   Type emulatedElemType) {
+static VectorValue emulatedVectorLoad(OpBuilder &rewriter, Location loc,
+                                      Value base,
+                                      OpFoldResult linearizedIndices,
+                                      int64_t numEmultedElementsToLoad,
+                                      Type origElemType,
+                                      Type emulatedElemType) {
   auto scale = emulatedElemType.getIntOrFloatBitWidth() /
                origElemType.getIntOrFloatBitWidth();
   auto newLoad = rewriter.create<vector::LoadOp>(
@@ -292,6 +296,89 @@ emulatedVectorLoad(OpBuilder &rewriter, Location loc, Value base,
       newLoad);
 }
 
+/// Selects values from two sources based on a mask, and casts the result to a
+/// new type.
+static Value selectAndCast(OpBuilder &builder, Location loc,
+                           VectorType castIntoType, Value mask, Value trueValue,
+                           Value falseValue) {
+  Value maskedValue =
+      builder.create<arith::SelectOp>(loc, mask, trueValue, falseValue);
+  return builder.create<vector::BitCastOp>(loc, castIntoType, maskedValue);
+}
+
+/// Emits `memref.generic_atomic_rmw` op to store a subbyte-sized value to a
+/// byte in memory, with a mask. The `valueToStore` is a vector of subbyte-sized
+/// elements, with size of 8 bits, and the mask is used to select which elements
+/// to store.
+///
+/// Inputs:
+///   linearizedMemref = |2|2|2|2| : <4xi2> (<1xi8>)
+///   linearizedIndex = 2
+///   valueToStore = |3|3|3|3| : vector<4xi2>
+///   mask = |0|0|1|1| : vector<4xi1>
+///
+/// Result:
+///   linearizedMemref = |2|2|3|3| : <4xi2> (<1xi8>)
+static void atomicStore(OpBuilder &builder, Location loc,
+                        MemRefValue linearizedMemref, Value linearizedIndex,
+                        VectorValue valueToStore, Value mask) {
+  assert(valueToStore.getType().getRank() == 1 && "expected 1-D vector");
+
+  // Create an atomic load-modify-write region using
+  // `memref.generic_atomic_rmw`.
+  auto atomicOp = builder.create<memref::GenericAtomicRMWOp>(
+      loc, linearizedMemref, ValueRange{linearizedIndex});
+  Value origValue = atomicOp.getCurrentValue();
+
+  OpBuilder::InsertionGuard guard(builder);
+  builder.setInsertionPointToStart(atomicOp.getBody());
+
+  // Load the original value from memory, and cast it to the original element
+  // type.
+  auto oneElemVecType = VectorType::get({1}, origValue.getType());
+  Value origVecValue = builder.create<vector::FromElementsOp>(
+      loc, oneElemVecType, ValueRange{origValue});
+  origVecValue = builder.create<vector::BitCastOp>(loc, valueToStore.getType(),
+                                                   origVecValue);
+
+  // Construct the final masked value and yield it.
+  Value maskedValue = selectAndCast(builder, loc, oneElemVecType, mask,
+                                    valueToStore, origVecValue);
+  auto scalarMaskedValue =
+      builder.create<vector::ExtractOp>(loc, maskedValue, 0);
+  builder.create<memref::AtomicYieldOp>(loc, scalarMaskedValue);
+}
+
+/// Extract `sliceNumElements` from source `vector` at `sliceOffset`,
+/// and insert it into an empty vector at offset `byteOffset`.
+/// Inputs:
+///   vector = |1|2|3|4| : vector<4xi2>
+///   sliceOffset = 1
+///   sliceNumElements = 2
+///   byteOffset = 2
+/// Output:
+///   vector = |0|0|2|3| : vector<4xi2>
+static Value extractSliceIntoByte(ConversionPatternRewriter &rewriter,
+                                  Location loc, VectorValue vector,
+                                  int64_t sliceOffset, int64_t sliceNumElements,
+                                  int64_t byteOffset) {
+  assert(vector.getType().getRank() == 1 && "expected 1-D vector");
+  auto vectorElementType = vector.getType().getElementType();
+  assert(
+      sliceNumElements * vectorElementType.getIntOrFloatBitWidth() <= 8 &&
+      "sliceNumElements * vector element size must be less than or equal to 8");
+  assert(8 % vectorElementType.getIntOrFloatBitWidth() == 0 &&
+         "vector element must be a valid sub-byte type");
+  auto scale = 8 / vectorElementType.getIntOrFloatBitWidth();
+  auto emptyByteVector = rewriter.create<arith::ConstantOp>(
+      loc, VectorType::get({scale}, vectorElementType),
+      rewriter.getZeroAttr(VectorType::get({scale}, vectorElementType)));
+  auto extracted = staticallyExtractSubvector(rewriter, loc, vector,
+                                              sliceOffset, sliceNumElements);
+  return staticallyInsertSubvector(rewriter, loc, extracted, emptyByteVector,
+                                   byteOffset);
+}
+
 namespace {
 
 //===----------------------------------------------------------------------===//
@@ -301,6 +388,9 @@ namespace {
 struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
   using OpConversionPattern::OpConversionPattern;
 
+  ConvertVectorStore(MLIRContext *context, bool useAtomicWrites)
+      : OpConversionPattern<vector::StoreOp>(context) {}
+
   LogicalResult
   matchAndRewrite(vector::StoreOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
@@ -312,8 +402,9 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
 
     auto loc = op.getLoc();
     auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
-    Type oldElementType = op.getValueToStore().getType().getElementType();
-    Type newElementType = convertedType.getElementType();
+    auto valueToStore = cast<VectorValue>(op.getValueToStore());
+    auto oldElementType = valueToStore.getType().getElementType();
+    auto newElementType = convertedType.getElementType();
     int srcBits = oldElementType.getIntOrFloatBitWidth();
     int dstBits = newElementType.getIntOrFloatBitWidth();
 
@@ -321,7 +412,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
       return rewriter.notifyMatchFailure(
           op, "only dstBits % srcBits == 0 supported");
     }
-    int scale = dstBits / srcBits;
+    int numSrcElemsPerDest = dstBits / srcBits;
 
     // Adjust the number of elements to store when emulating narrow types.
     // Here only the 1-D vector store is considered, and the N-D memref types
@@ -336,15 +427,15 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
     // vector.store %bitcast, %alloc[%linear_index] : memref<16xi8>,
     // vector<4xi8>
 
-    auto origElements = op.getValueToStore().getType().getNumElements();
-    if (origElements % scale != 0)
-      return failure();
+    auto origElements = valueToStore.getType().getNumElements();
+    bool isUnalignedEmulation = origElements % numSrcElemsPerDest != 0;
 
     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
 
     OpFoldResult linearizedIndices;
-    std::tie(std::ignore, linearizedIndices) =
+    memref::LinearizedMemRefInfo linearizedInfo;
+    std::tie(linearizedInfo, linearizedIndices) =
         memref::getLinearizedMemRefOffsetAndSize(
             rewriter, loc, srcBits, dstBits,
             stridedMetadata.getConstifiedMixedOffset(),
@@ -352,14 +443,122 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
             stridedMetadata.getConstifiedMixedStrides(),
             getAsOpFoldResult(adaptor.getIndices()));
 
-    auto numElements = origElements / scale;
-    auto bitCast = rewriter.create<vector::BitCastOp>(
-        loc, VectorType::get(numElements, newElementType),
-        op.getValueToStore());
+    auto foldedNumFrontPadElems =
+        isUnalignedEmulation
+            ? getConstantIntValue(linearizedInfo.intraDataOffset)
+            : 0;
+
+    if (!foldedNumFrontPadElems) {
+      // Unimplemented case for dynamic front padding size != 0
+      return failure();
+    }
+
+    auto linearizedMemref = cast<MemRefValue>(adaptor.getBase());
+
+    // Shortcut: conditions when subbyte store at the front is not needed:
+    // 1. The source vector size is multiple of byte size
+    // 2. The address of the store is aligned to the emulated width boundary
+    if (!isUnalignedEmulation && *foldedNumFrontPadElems == 0) {
+      auto numElements = origElements / numSrcElemsPerDest;
+      auto bitCast = rewriter.create<vector::BitCastOp>(
+          loc, VectorType::get(numElements, newElementType),
+          op.getValueToStore());
+      rewriter.replaceOpWithNewOp<vector::StoreOp>(
+          op, bitCast.getResult(), linearizedMemref,
+          getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
+      return success();
+    }
+
+    // The index into the target memref we are storing to
+    Value currentDestIndex =
+        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices);
+    auto subWidthStoreMaskType =
+        VectorType::get({numSrcElemsPerDest}, rewriter.getI1Type());
+    // The index into the source vector we are currently processing
+    auto currentSourceIndex = 0;
+
+    // 1. Partial width store for the first byte, when the store address is not
+    // aligned to emulated width boundary, deal with the unaligned part so that
+    // the rest elements are aligned to width boundary.
+    auto frontSubWidthStoreElem =
+        (numSrcElemsPerDest - *foldedNumFrontPadElems) % numSrcElemsPerDest;
+    if (frontSubWidthStoreElem > 0) {
+      SmallVector<bool> frontMaskValues(numSrcElemsPerDest, false);
+      if (*foldedNumFrontPadElems + origElements < numSrcElemsPerDest) {
+        std::fill_n(frontMaskValues.begin() + *foldedNumFrontPadElems,
+                    origElements, true);
+        frontSubWidthStoreElem = origElements;
+      } else {
+        std::fill_n(frontMaskValues.end() - frontSubWidthStoreElem,
+                    *foldedNumFrontPadElems, true);
+      }
+      auto frontMask = rewriter.create<arith::ConstantOp>(
+          loc, DenseElementsAttr::get(subWidthStoreMaskType, frontMaskValues));
+
+      currentSourceIndex = numSrcElemsPerDest - (*foldedNumFrontPadElems);
+      auto value =
+          extractSliceIntoByte(rewriter, loc, valueToStore, 0,
+                               frontSubWidthStoreElem, *foldedNumFrontPadElems);
 
-    rewriter.replaceOpWithNewOp<vector::StoreOp>(
-        op, bitCast.getResult(), adaptor.getBase(),
-        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
+      atomicStore(rewriter, loc, linearizedMemref, currentDestIndex,
+                  cast<VectorValue>(value), frontMask.getResult());
+    }
+
+    if (currentSourceIndex >= origElements) {
+      rewriter.eraseOp(op);
+      return success();
+    }
+
+    // Increment the destination index by 1 to align to the emulated width
+    // boundary.
+    auto constantOne = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+    currentDestIndex = rewriter.create<arith::AddIOp>(
+        loc, rewriter.getIndexType(), currentDestIndex, constantOne);
+
+    // 2. Full width store. After the previous step, the store address is
+    // aligned to the emulated width boundary.
+    int64_t fullWidthStoreSize =
+        (origElements - currentSourceIndex) / numSrcElemsPerDest;
+    int64_t numNonFullWidthElements = fullWidthStoreSize * numSrcElemsPerDest;
+    if (fullWidthStoreSize > 0) {
+      auto fullWidthStorePart = staticallyExtractSubvector(
+          rewriter, loc, valueToStore, currentSourceIndex,
+          numNonFullWidthElements);
+
+      auto originType = cast<VectorType>(fullWidthStorePart.getType());
+      auto memrefElemType = getElementTypeOrSelf(linearizedMemref.getType());
+      auto storeType = VectorType::get(
+          {originType.getNumElements() / numSrcElemsPerDest}, memrefElemType);
+      auto bitCast = rewriter.create<vector::BitCastOp>(loc, storeType,
+                                                        fullWidthStorePart);
+      rewriter.create<vector::StoreOp>(loc, bitCast.getResult(),
+                                       linearizedMemref, currentDestIndex);
+
+      currentSourceIndex += numNonFullWidthElements;
+      currentDestIndex = rewriter.create<arith::AddIOp>(
+          loc, rewriter.getIndexType(), currentDestIndex,
+          rewriter.create<arith::ConstantIndexOp>(loc, fullWidthStoreSize));
+    }
+
+    // 3. Deal with trailing elements that are aligned to the emulated width,
+    // but their length is smaller than the emulated width.
+    auto remainingElements = origElements - currentSourceIndex;
+    if (remainingElements != 0) {
+      auto subWidthStorePart =
+          extractSliceIntoByte(rewriter, loc, cast<VectorValue>(valueToStore),
+                               currentSourceIndex, remainingElements, 0);
+
+      // Generate back mask
+      auto maskValues = SmallVector<bool>(numSrcElemsPerDest, 0);
+      std::fill_n(maskValues.begin(), remainingElements, 1);
+      auto backMask = rewriter.create<arith::ConstantOp>(
+          loc, DenseElementsAttr::get(subWidthStoreMaskType, maskValues));
+
+      atomicStore(rewriter, loc, linearizedMemref, currentDestIndex,
+                  cast<VectorValue>(subWidthStorePart), backMask.getResult());
+    }
+
+    rewriter.eraseOp(op);
     return success();
   }
 };
@@ -564,12 +763,11 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
       auto resultVector = rewriter.create<arith::ConstantOp>(
           loc, op.getType(), rewriter.getZeroAttr(op.getType()));
       result = dynamicallyExtractSubVector(
-          rewriter, loc, dyn_cast<TypedValue<VectorType>>(result), resultVector,
+          rewriter, loc, cast<VectorValue>(result), resultVector,
           linearizedInfo.intraDataOffset, origElements);
     } else if (isUnalignedEmulation) {
-      result =
-          staticallyExtractSubvector(rewriter, loc, op.getType(), result,
-                                     *foldedIntraVectorOffset, origElements);
+      result = staticallyExtractSubvector(
+          rewriter, loc, result, *foldedIntraVectorOffset, origElements);
     }
     rewriter.replaceOp(op, result);
     return success();
@@ -685,8 +883,8 @@ struct ConvertVectorMaskedLoad final
         loc, newBitcastType, rewriter.getZeroAttr(newBitcastType));
     if (!foldedIntraVectorOffset) {
       passthru = dynamicallyInsertSubVector(
-          rewriter, loc, dyn_cast<TypedValue<VectorType>>(passthru),
-          emptyVector, linearizedInfo.intraDataOffset, origElements);
+          rewriter, loc, cast<VectorValue>(passthru), emptyVector,
+          linearizedInfo.intraDataOffset, origElements);
     } else if (isUnalignedEmulation) {
       passthru = staticallyInsertSubvector(rewriter, loc, passthru, emptyVector,
                                            *foldedIntraVectorOffset);
@@ -713,7 +911,7 @@ struct ConvertVectorMaskedLoad final
         loc, newSelectMaskType, rewriter.getZeroAttr(newSelectMaskType));
     if (!foldedIntraVectorOffset) {
       mask = dynamicallyInsertSubVector(
-          rewriter, loc, dyn_cast<TypedValue<VectorType>>(mask), emptyMask,
+          rewriter, loc, cast<VectorValue>(mask), emptyMask,
           linearizedInfo.intraDataOffset, origElements);
     } else if (isUnalignedEmulation) {
       mask = staticallyInsertSubvector(rewriter, loc, op.getMask(), emptyMask,
@@ -724,12 +922,11 @@ struct ConvertVectorMaskedLoad final
         rewriter.create<arith::SelectOp>(loc, mask, bitCast, passthru);
     if (!foldedIntraVectorOffset) {
       result = dynamicallyExtractSubVector(
-          rewriter, loc, dyn_cast<TypedValue<VectorType>>(result),
-          op.getPassThru(), linearizedInfo.intraDataOffset, origElements);
+          rewriter, loc, cast<VectorValue>(result), op.getPassThru(),
+          linearizedInfo.intraDataOffset, origElements);
     } else if (isUnalignedEmulation) {
-      result =
-          staticallyExtractSubvector(rewriter, loc, op.getType(), result,
-                                     *foldedIntraVectorOffset, origElements);
+      result = staticallyExtractSubvector(
+          rewriter, loc, result, *foldedIntraVectorOffset, origElements);
     }
     rewriter.replaceOp(op, result);
 
@@ -812,9 +1009,8 @@ struct ConvertVectorTransferRead final
                                            linearizedInfo.intraDataOffset,
                                            origElements);
     } else if (isUnalignedEmulation) {
-      result =
-          staticallyExtractSubvector(rewriter, loc, op.getType(), result,
-                                     *foldedIntraVectorOffset, origElements);
+      result = staticallyExtractSubvector(
+          rewriter, loc, result, *foldedIntraVectorOffset, origElements);
     }
     rewriter.replaceOp(op, result);
 
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
index 4332e80feed421..b01f9165d9eb74 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
@@ -356,3 +356,140 @@ func.func @vector_maskedload_i2_constant_mask_unaligned(%passthru: vector<5xi2>)
 // CHECK: %[[RESULT:.+]] = vector.extract_strided_slice %[[SELECT]]
 // CHECK-SAME: {offsets = [1], sizes = [5], strides = [1]} : vector<8xi2> to vector<5xi2>
 // CHECK: return %[[RESULT]] : vector<5xi2>
+
+///----------------------------------------------------------------------------------------
+/// vector.store
+///----------------------------------------------------------------------------------------
+
+func.func @vector_store_i2_const_index_two_atomic(%arg0: vector<3xi2>) {
+    %0 = memref.alloc() : memref<3x3xi2>
+    %c0 = arith.constant 0 : index
+    %c2 = arith.constant 2 : index
+    vector.store %arg0, %0[%c2, %c0] :memref<3x3xi2>, vector<3xi2>
+    return
+}
+
+// In this example, emit 2 atomic RMWs.
+// Store to bits [12:18), i.e. bytes [1:2] of 3 total bytes; both bytes need an RMW.
+
+// CHECK-LABEL: func @vector_store_i2_const_index_two_atomic(
+// CHECK-SAME: %[[ARG0:.+]]: vector<3xi2>)
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[CST:.+]] = arith.constant dense<[false, false, true, true]> : vector<4xi1>
+// CHECK: %[[CST_0:.+]] = arith.constant dense<0> : vector<4xi2>
+
+// Part 1 atomic RMW sequence
+// CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [0], sizes = [2], strides = [1]} : vector<3xi2> to vector<2xi2>
+// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[EXTRACT]], %[[CST_0]]
+// CHECK-SAME: {offsets = [2], strides = [1]} : vector<2xi2> into vector<4xi2>
+// CHECK: %[[ATOMIC_RMW:.+]] = memref.generic_atomic_rmw %[[ALLOC]][%[[C1]]] : memref<3xi8> {
+// CHECK: %[[ARG:.+]]: i8):
+// CHECK: %[[FROM_ELEM:.+]] = vector.from_elements %[[ARG]] : vector<1xi8>
+// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[FROM_ELEM]] : vector<1xi8> to vector<4xi2>
+// CHECK: %[[SELECT:.+]] = arith.select %[[CST]], %[[INSERT]], %[[BITCAST]] : vector<4xi1>, vector<4xi2>
+// CHECK: %[[BITCAST2:.+]] = vector.bitcast %[[SELECT]] : vector<4xi2> to vector<1xi8>
+// CHECK: %[[EXTRACT2:.+]] = vector.extract %[[BITCAST2]][0] : i8 from vector<1xi8>
+// CHECK: memref.atomic_yield %[[EXTRACT2]] : i8
+
+// Part 2 atomic RMW sequence
+// CHECK: %[[ADDR2:.+]] = arith.addi %[[C1]], %[[C1]] : index
+// CHECK: %[[EXTRACT3:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [2], sizes = [1], strides = [1]} : vector<3xi2> to vector<1xi2>
+// CHECK: %[[INSERT2:.+]] = vector.insert_strided_slice %[[EXTRACT3]], %[[CST_0]]
+// CHECK-SAME: {offsets = [0], strides = [1]} : vector<1xi2> into vector<4xi2>
+// CHECK: %[[CST1:.+]] = arith.constant dense<[true, false, false, false]> : vector<4xi1>
+// CHECK: %[[ATOMIC_RMW2:.+]] = memref.generic_atomic_rmw %[[ALLOC]][%[[ADDR2]]] : memref<3xi8> {
+// CHECK: %[[ARG2:.+]]: i8):
+// CHECK: %[[FROM_ELEM2:.+]] = vector.from_elements %[[ARG2]] : vector<1xi8>
+// CHECK: %[[BITCAST4:.+]] = vector.bitcast %[[FROM_ELEM2]] : vector<1xi8> to vector<4xi2>
+// CHECK: %[[SELECT2:.+]] = arith.select %[[CST1]], %[[INSERT2]], %[[BITCAST4]] : vector<4xi1>, vector<4xi2>
+// CHECK: %[[BITCAST5:.+]] = vector.bitcast %[[SELECT2]] : vector<4xi2> to vector<1xi8>
+// CHECK: %[[EXTRACT4:.+]] = vector.extract %[[BITCAST5]][0] : i8 from vector<1xi8>
+// CHECK: memref.atomic_yield %[[EXTRACT4]] : i8
+
+// -----
+
+func.func @vector_store_i2_atomic(%arg0: vector<7xi2>) {
+    %0 = memref.alloc() : memref<3x7xi2>
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    vector.store %arg0, %0[%c1, %c0] :memref<3x7xi2>, vector<7xi2>
+    return
+}
+
+// In this example, emit 2 atomic RMWs and 1 non-atomic store:
+// CHECK-LABEL: func @vector_store_i2_atomic(
+// CHECK-SAME: %[[ARG0:.+]]: vector<7xi2>)
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<6xi8>
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[CST:.+]] = arith.constant dense<[false, false, false, true]> : vector<4xi1>
+// CHECK: %[[CST0:.+]] = arith.constant dense<0> : vector<4xi2>
+
+// First atomic RMW:
+// CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [0], sizes = [1], strides = [1]} : vector<7xi2> to vector<1xi2>
+// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[EXTRACT]], %[[CST0]]
+// CHECK-SAME: {offsets = [3], strides = [1]} : vector<1xi2> into vector<4xi2>
+// CHECK: %[[ATOMIC_RMW:.+]] = memref.generic_atomic_rmw %[[ALLOC]][%[[C1]]] : memref<6xi8> {
+// CHECK: %[[ARG:.+]]: i8):
+// CHECK: %[[FROM_ELEM:.+]] = vector.from_elements %[[ARG]] : vector<1xi8>
+// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[FROM_ELEM]] : vector<1xi8> to vector<4xi2>
+// CHECK: %[[SELECT:.+]] = arith.select %[[CST]], %[[INSERT]], %[[BITCAST]] : vector<4xi1>, vector<4xi2>
+// CHECK: %[[BITCAST2:.+]] = vector.bitcast %[[SELECT]] : vector<4xi2> to vector<1xi8>
+// CHECK: %[[EXTRACT2:.+]] = vector.extract %[[BITCAST2]][0] : i8 from vector<1xi8>
+// CHECK: memref.atomic_yield %[[EXTRACT2]] : i8
+
+// Non-atomic store:
+// CHECK: %[[ADDR:.+]] = arith.addi %[[C1]], %[[C1]] : index
+// CHECK: %[[EXTRACT2:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [1], sizes = [4], strides = [1]} : vector<7xi2> to vector<4xi2>
+// CHECK: %[[BITCAST3:.+]] = vector.bitcast %[[EXTRACT2]] : vector<4xi2> to vector<1xi8>
+// CHECK: vector.store %[[BITCAST3]], %[[ALLOC]][%[[ADDR]]] : memref<6xi8>, vector<1xi8>
+
+// Second atomic RMW:
+// CHECK: %[[ADDR2:.+]] = arith.addi %[[ADDR]], %[[C1]] : index
+// CHECK: %[[EXTRACT3:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [5], sizes = [2], strides = [1]} : vector<7xi2> to vector<2xi2>
+// CHECK: %[[INSERT2:.+]] = vector.insert_strided_slice %[[EXTRACT3]], %[[CST0]]
+// CHECK-SAME: {offsets = [0], strides = [1]} : vector<2xi2> into vector<4xi2>
+// CHECK: %[[CST1:.+]] = arith.constant dense<[true, true, false, false]> : vector<4xi1> 
+// CHECK: %[[ATOMIC_RMW2:.+]] = memref.generic_atomic_rmw %[[ALLOC]][%[[ADDR2]]] : memref<6xi8> {
+// CHECK: %[[ARG2:.+]]: i8):
+// CHECK: %[[FROM_ELEM2:.+]] = vector.from_elements %[[ARG2]] : vector<1xi8>
+// CHECK: %[[BITCAST4:.+]] = vector.bitcast %[[FROM_ELEM2]] : vector<1xi8> to vector<4xi2>
+// CHECK: %[[SELECT2:.+]] = arith.select %[[CST1]], %[[INSERT2]], %[[BITCAST4]] :
+// CHECK-SAME: vector<4xi1>, vector<4xi2>
+// CHECK: %[[BITCAST5:.+]] = vector.bitcast %[[SELECT2]] : vector<4xi2> to vector<1xi8>
+// CHECK: %[[EXTRACT4:.+]] = vector.extract %[[BITCAST5]][0] : i8 from vector<1xi8>
+// CHECK: memref.atomic_yield %[[EXTRACT4]] : i8    
+
+// -----
+
+func.func @vector_store_i2_single_atomic(%arg0: vector<1xi2>) {
+    %0 = memref.alloc() : memref<4x1xi2>
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    vector.store %arg0, %0[%c1, %c0] :memref<4x1xi2>, vector<1xi2>
+    return
+}
+
+// In this example, only emit 1 atomic store
+// CHECK-LABEL: func @vector_store_i2_single_atomic(
+// CHECK-SAME: %[[ARG0:.+]]: vector<1xi2>)
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<1xi8>
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[CST:.+]] = arith.constant dense<[false, true, false, false]> : vector<4xi1>
+// CHECK: %[[CST0:.+]] = arith.constant dense<0> : vector<4xi2>
+// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[ARG0]], %[[CST0]]
+// CHECK-SAME: {offsets = [1], strides = [1]} : vector<1xi2> into vector<4xi2>
+
+// CHECK: %[[ATOMIC_RMW:.+]] = memref.generic_atomic_rmw %[[ALLOC]][%[[C0]]] : memref<1xi8> {
+// CHECK: %[[ARG:.+]]: i8):
+// CHECK: %[[FROM_ELEM:.+]] = vector.from_elements %[[ARG]] : vector<1xi8>
+// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[FROM_ELEM]] : vector<1xi8> to vector<4xi2>
+// CHECK: %[[SELECT:.+]] = arith.select %[[CST]], %[[INSERT]], %[[BITCAST]] : vector<4xi1>, vector<4xi2>
+// CHECK: %[[BITCAST2:.+]] = vector.bitcast %[[SELECT]] : vector<4xi2> to vector<1xi8>
+// CHECK: %[[EXTRACT2:.+]] = vector.extract %[[BITCAST2]][0] : i8 from vector<1xi8>
+// CHECK: memref.atomic_yield %[[EXTRACT2]] : i8

>From cb67d62ab47788b34a50b0cad5a413754b9fa2a2 Mon Sep 17 00:00:00 2001
From: Ubuntu <450283+lialan at users.noreply.github.com>
Date: Sat, 11 Jan 2025 09:19:39 +0000
Subject: [PATCH 2/2] Update according to comments

---
 .../Transforms/VectorEmulateNarrowType.cpp    | 65 +++++++++++--------
 .../vector-emulate-narrow-type-unaligned.mlir | 24 +++----
 2 files changed, 51 insertions(+), 38 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 38f6ce78b76eae..1f0d7bbede491d 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -296,38 +296,49 @@ static VectorValue emulatedVectorLoad(OpBuilder &rewriter, Location loc,
       newLoad);
 }
 
-/// Selects values from two sources based on a mask, and casts the result to a
-/// new type.
-static Value selectAndCast(OpBuilder &builder, Location loc,
-                           VectorType castIntoType, Value mask, Value trueValue,
-                           Value falseValue) {
-  Value maskedValue =
+/// Downcasts two values to `downcastType`, then selects values
+/// based on `mask`, and casts the result to `upcastType`.
+static Value downcastSelectAndUpcast(OpBuilder &builder, Location loc,
+                                     VectorType downcastType,
+                                     VectorType upcastType, Value mask,
+                                     Value trueValue, Value falseValue) {
+  assert(
+      downcastType.getNumElements() * downcastType.getElementTypeBitWidth() ==
+          upcastType.getNumElements() * upcastType.getElementTypeBitWidth() &&
+      "expected upcastType size to be twice the size of downcastType");
+  if (trueValue.getType() != downcastType)
+    trueValue = builder.create<vector::BitCastOp>(loc, downcastType, trueValue);
+  if (falseValue.getType() != downcastType)
+    falseValue =
+        builder.create<vector::BitCastOp>(loc, downcastType, falseValue);
+  Value selectedType =
       builder.create<arith::SelectOp>(loc, mask, trueValue, falseValue);
-  return builder.create<vector::BitCastOp>(loc, castIntoType, maskedValue);
+  // Upcast the selected value to the new type.
+  return builder.create<vector::BitCastOp>(loc, upcastType, selectedType);
 }
 
 /// Emits `memref.generic_atomic_rmw` op to store a subbyte-sized value to a
-/// byte in memory, with a mask. The `valueToStore` is a vector of subbyte-sized
-/// elements, with size of 8 bits, and the mask is used to select which elements
-/// to store.
+/// byte in `linearizedMemref`, with a mask. The `valueToStore` is a vector of
+/// subbyte-sized elements with a total size of 8 bits, and the mask is used to
+/// select which elements to store.
 ///
 /// Inputs:
 ///   linearizedMemref = |2|2|2|2| : <4xi2> (<1xi8>)
-///   linearizedIndex = 2
+///   storeIdx = 2
 ///   valueToStore = |3|3|3|3| : vector<4xi2>
 ///   mask = |0|0|1|1| : vector<4xi1>
 ///
 /// Result:
 ///   linearizedMemref = |2|2|3|3| : <4xi2> (<1xi8>)
 static void atomicStore(OpBuilder &builder, Location loc,
-                        MemRefValue linearizedMemref, Value linearizedIndex,
+                        MemRefValue linearizedMemref, Value storeIdx,
                         VectorValue valueToStore, Value mask) {
   assert(valueToStore.getType().getRank() == 1 && "expected 1-D vector");
 
   // Create an atomic load-modify-write region using
   // `memref.generic_atomic_rmw`.
   auto atomicOp = builder.create<memref::GenericAtomicRMWOp>(
-      loc, linearizedMemref, ValueRange{linearizedIndex});
+      loc, linearizedMemref, ValueRange{storeIdx});
   Value origValue = atomicOp.getCurrentValue();
 
   OpBuilder::InsertionGuard guard(builder);
@@ -338,30 +349,30 @@ static void atomicStore(OpBuilder &builder, Location loc,
   auto oneElemVecType = VectorType::get({1}, origValue.getType());
   Value origVecValue = builder.create<vector::FromElementsOp>(
       loc, oneElemVecType, ValueRange{origValue});
-  origVecValue = builder.create<vector::BitCastOp>(loc, valueToStore.getType(),
-                                                   origVecValue);
 
   // Construct the final masked value and yield it.
-  Value maskedValue = selectAndCast(builder, loc, oneElemVecType, mask,
-                                    valueToStore, origVecValue);
+  Value maskedValue =
+      downcastSelectAndUpcast(builder, loc, valueToStore.getType(),
+                              oneElemVecType, mask, valueToStore, origVecValue);
   auto scalarMaskedValue =
       builder.create<vector::ExtractOp>(loc, maskedValue, 0);
   builder.create<memref::AtomicYieldOp>(loc, scalarMaskedValue);
 }
 
-/// Extract `sliceNumElements` from source `vector` at `sliceOffset`,
-/// and insert it into an empty vector at offset `byteOffset`.
+/// Extract `sliceNumElements` from source `vector` at `extractOffset`,
+/// and insert it into an empty vector at `insertOffset`.
 /// Inputs:
-///   vector = |1|2|3|4| : vector<4xi2>
-///   sliceOffset = 1
+///   vec_in  = |0|1|2|3| : vector<4xi2>
+///   extractOffset = 1
 ///   sliceNumElements = 2
-///   byteOffset = 2
+///   insertOffset = 2
 /// Output:
-///   vector = |0|0|2|3| : vector<4xi2>
+///   vec_out = |0|0|1|2| : vector<4xi2>
 static Value extractSliceIntoByte(ConversionPatternRewriter &rewriter,
                                   Location loc, VectorValue vector,
-                                  int64_t sliceOffset, int64_t sliceNumElements,
-                                  int64_t byteOffset) {
+                                  int64_t extractOffset,
+                                  int64_t sliceNumElements,
+                                  int64_t insertOffset) {
   assert(vector.getType().getRank() == 1 && "expected 1-D vector");
   auto vectorElementType = vector.getType().getElementType();
   assert(
@@ -374,9 +385,9 @@ static Value extractSliceIntoByte(ConversionPatternRewriter &rewriter,
       loc, VectorType::get({scale}, vectorElementType),
       rewriter.getZeroAttr(VectorType::get({scale}, vectorElementType)));
   auto extracted = staticallyExtractSubvector(rewriter, loc, vector,
-                                              sliceOffset, sliceNumElements);
+                                              extractOffset, sliceNumElements);
   return staticallyInsertSubvector(rewriter, loc, extracted, emptyByteVector,
-                                   byteOffset);
+                                   insertOffset);
 }
 
 namespace {
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
index b01f9165d9eb74..a80ab7b7e4166e 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
@@ -361,25 +361,27 @@ func.func @vector_maskedload_i2_constant_mask_unaligned(%passthru: vector<5xi2>)
 /// vector.store
 ///----------------------------------------------------------------------------------------
 
-func.func @vector_store_i2_const_index_two_atomic(%arg0: vector<3xi2>) {
-    %0 = memref.alloc() : memref<3x3xi2>
+func.func @vector_store_i2_const_index_two_atomic_rmw(%arg0: vector<3xi2>) {
+    %src = memref.alloc() : memref<3x3xi2>
     %c0 = arith.constant 0 : index
     %c2 = arith.constant 2 : index
-    vector.store %arg0, %0[%c2, %c0] :memref<3x3xi2>, vector<3xi2>
+    vector.store %arg0, %src[%c2, %c0] :memref<3x3xi2>, vector<3xi2>
     return
 }
 
 // In this example, emit 2 atomic RMWs.
-// Load from bit [12:18), byte [1:2] of total 3 bytes, both bytes needs rmw.
+//
+// Note, sizeof(%src) = 18 bits. This is modelled as %src_as_bytes:
+// <3xi8> (bits [0, 18) with the input values from %src, and [18, 24) are masked out)
 
-// CHECK-LABEL: func @vector_store_i2_const_index_two_atomic(
+// CHECK-LABEL: func @vector_store_i2_const_index_two_atomic_rmw(
 // CHECK-SAME: %[[ARG0:.+]]: vector<3xi2>)
 // CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
 // CHECK: %[[C1:.+]] = arith.constant 1 : index
 // CHECK: %[[CST:.+]] = arith.constant dense<[false, false, true, true]> : vector<4xi1>
 // CHECK: %[[CST_0:.+]] = arith.constant dense<0> : vector<4xi2>
 
-// Part 1 atomic RMW sequence
+// Part 1 atomic RMW sequence (load bits [12, 16) from %src_as_bytes[1])
 // CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[ARG0]]
 // CHECK-SAME: {offsets = [0], sizes = [2], strides = [1]} : vector<3xi2> to vector<2xi2>
 // CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[EXTRACT]], %[[CST_0]]
@@ -393,7 +395,7 @@ func.func @vector_store_i2_const_index_two_atomic(%arg0: vector<3xi2>) {
 // CHECK: %[[EXTRACT2:.+]] = vector.extract %[[BITCAST2]][0] : i8 from vector<1xi8>
 // CHECK: memref.atomic_yield %[[EXTRACT2]] : i8
 
-// Part 2 atomic RMW sequence
+// Part 2 atomic RMW sequence (load bits [16, 18) from %src_as_bytes[2])
 // CHECK: %[[ADDR2:.+]] = arith.addi %[[C1]], %[[C1]] : index
 // CHECK: %[[EXTRACT3:.+]] = vector.extract_strided_slice %[[ARG0]]
 // CHECK-SAME: {offsets = [2], sizes = [1], strides = [1]} : vector<3xi2> to vector<1xi2>
@@ -411,7 +413,7 @@ func.func @vector_store_i2_const_index_two_atomic(%arg0: vector<3xi2>) {
 
 // -----
 
-func.func @vector_store_i2_atomic(%arg0: vector<7xi2>) {
+func.func @vector_store_i2_atomic_rmw(%arg0: vector<7xi2>) {
     %0 = memref.alloc() : memref<3x7xi2>
     %c0 = arith.constant 0 : index
     %c1 = arith.constant 1 : index
@@ -420,7 +422,7 @@ func.func @vector_store_i2_atomic(%arg0: vector<7xi2>) {
 }
 
 // In this example, emit 2 atomic RMWs and 1 non-atomic store:
-// CHECK-LABEL: func @vector_store_i2_atomic(
+// CHECK-LABEL: func @vector_store_i2_atomic_rmw(
 // CHECK-SAME: %[[ARG0:.+]]: vector<7xi2>)
 // CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<6xi8>
 // CHECK: %[[C1:.+]] = arith.constant 1 : index
@@ -467,7 +469,7 @@ func.func @vector_store_i2_atomic(%arg0: vector<7xi2>) {
 
 // -----
 
-func.func @vector_store_i2_single_atomic(%arg0: vector<1xi2>) {
+func.func @vector_store_i2_const_index_one_atomic_rmw(%arg0: vector<1xi2>) {
     %0 = memref.alloc() : memref<4x1xi2>
     %c0 = arith.constant 0 : index
     %c1 = arith.constant 1 : index
@@ -476,7 +478,7 @@ func.func @vector_store_i2_single_atomic(%arg0: vector<1xi2>) {
 }
 
 // In this example, only emit 1 atomic store
-// CHECK-LABEL: func @vector_store_i2_single_atomic(
+// CHECK-LABEL: func @vector_store_i2_const_index_one_atomic_rmw(
 // CHECK-SAME: %[[ARG0:.+]]: vector<1xi2>)
 // CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<1xi8>
 // CHECK: %[[C0:.+]] = arith.constant 0 : index



More information about the Mlir-commits mailing list