[Mlir-commits] [mlir] [MLIR] Implement emulation of static indexing subbyte type vector stores (PR #115922)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Nov 12 13:59:23 PST 2024


https://github.com/lialan updated https://github.com/llvm/llvm-project/pull/115922

>From bfdbed21b5180bec317d8f1ec70b5f931e154509 Mon Sep 17 00:00:00 2001
From: Ubuntu <450283+lialan at users.noreply.github.com>
Date: Fri, 25 Oct 2024 15:19:42 +0000
Subject: [PATCH] Implement vector stores

---
 .../Transforms/VectorEmulateNarrowType.cpp    | 211 +++++++++++++++---
 .../vector-emulate-narrow-type-unaligned.mlir | 135 +++++++++++
 2 files changed, 313 insertions(+), 33 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index bb0731d768dfa7..75e6cc3f11f507 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -33,6 +33,7 @@
 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/Support/Debug.h"
+#include "llvm/Support/LogicalResult.h"
 #include "llvm/Support/MathExtras.h"
 #include "llvm/Support/raw_ostream.h"
 #include <cstdint>
@@ -143,19 +144,19 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
 /// Extracts 1-D subvector from a 1-D vector. It is a wrapper function for
 /// emitting `vector.extract_strided_slice`.
 static Value staticallyExtractSubvector(OpBuilder &rewriter, Location loc,
-                                        VectorType extractType, Value source,
-                                        int64_t frontOffset,
+                                        Value source, int64_t frontOffset,
                                         int64_t subvecSize) {
-  auto vectorType = cast<VectorType>(source.getType());
-  assert((vectorType.getRank() == 1 && extractType.getRank() == 1) &&
-         "expected 1-D source and destination types");
-  (void)vectorType;
+  auto sourceVectorType = cast<VectorType>(source.getType());
+  assert(sourceVectorType.getRank() == 1 && "expected 1-D source types");
   auto offsets = rewriter.getI64ArrayAttr({frontOffset});
   auto sizes = rewriter.getI64ArrayAttr({subvecSize});
   auto strides = rewriter.getI64ArrayAttr({1});
+  // The slice keeps the source element type and has `subvecSize` lanes.
+  VectorType sliceType =
+      VectorType::get({subvecSize}, sourceVectorType.getElementType());
   return rewriter
-      .create<vector::ExtractStridedSliceOp>(loc, extractType, source, offsets,
-                                             sizes, strides)
+      .create<vector::ExtractStridedSliceOp>(loc, sliceType, source, offsets,
+                                             sizes, strides)
       ->getResult(0);
 }
 
@@ -164,12 +165,10 @@ static Value staticallyExtractSubvector(OpBuilder &rewriter, Location loc,
 /// `vector.insert_strided_slice`.
 static Value staticallyInsertSubvector(OpBuilder &rewriter, Location loc,
                                        Value src, Value dest, int64_t offset) {
-  auto srcType = cast<VectorType>(src.getType());
-  auto destType = cast<VectorType>(dest.getType());
+  [[maybe_unused]] auto srcType = cast<VectorType>(src.getType());
+  [[maybe_unused]] auto destType = cast<VectorType>(dest.getType());
   assert(srcType.getRank() == 1 && destType.getRank() == 1 &&
          "expected source and dest to be vector type");
-  (void)srcType;
-  (void)destType;
   auto offsets = rewriter.getI64ArrayAttr({offset});
   auto strides = rewriter.getI64ArrayAttr({1});
   return rewriter.create<vector::InsertStridedSliceOp>(loc, dest.getType(), src,
@@ -236,6 +235,63 @@ emulatedVectorLoad(OpBuilder &rewriter, Location loc, Value base,
       newLoad);
 }
 
+/// Bitcasts `value` to whole `memref` elements and stores them non-atomically.
+static void nonAtomicStore(ConversionPatternRewriter &rewriter, Location loc,
+                           Value memref, Value index, Value value) {
+  auto srcType = cast<VectorType>(value.getType());
+  Type dstElemType = cast<MemRefType>(memref.getType()).getElementType();
+  int64_t scale = dstElemType.getIntOrFloatBitWidth() /
+                  srcType.getElementTypeBitWidth();
+  auto dstType = VectorType::get(srcType.getNumElements() / scale, dstElemType);
+  auto bitCast = rewriter.create<vector::BitCastOp>(loc, dstType, value);
+  rewriter.create<vector::StoreOp>(loc, bitCast.getResult(), memref, index);
+}
+
+/// Atomically merges the `scale`-lane `value` into the memref element at
+/// `emulatedIndex`: lanes where `mask` is set take `value`, others are kept.
+static Value atomicStore(OpBuilder &rewriter, Location loc,
+                         Value emulatedMemref, Value emulatedIndex,
+                         TypedValue<VectorType> value, Value mask,
+                         int64_t scale) {
+  assert(value.getType().getNumElements() == scale && "lane count mismatch");
+  auto atomicOp = rewriter.create<memref::GenericAtomicRMWOp>(
+      loc, emulatedMemref, ValueRange{emulatedIndex});
+  OpBuilder builder =
+      OpBuilder::atBlockEnd(atomicOp.getBody(), rewriter.getListener());
+  Value origValue = atomicOp.getCurrentValue();
+  // View the stored scalar as vector<1xT>, then as `value`'s sub-byte type.
+  auto oneVectorType = VectorType::get({1}, origValue.getType());
+  auto fromElem = builder.create<vector::FromElementsOp>(loc, oneVectorType,
+                                                         ValueRange{origValue});
+  auto current =
+      builder.create<vector::BitCastOp>(loc, value.getType(), fromElem);
+  // Merge the masked lanes, then pack the result back into a scalar.
+  auto merged = builder.create<arith::SelectOp>(loc, mask, value, current);
+  auto packed = builder.create<vector::BitCastOp>(loc, oneVectorType, merged);
+  auto extract = builder.create<vector::ExtractOp>(loc, packed, 0);
+  builder.create<memref::AtomicYieldOp>(loc, extract.getResult());
+  return atomicOp;
+}
+
+/// Extracts `sliceNumElements` elements of `vector` starting at `sliceOffset`
+/// and inserts them at lane `byteOffset` of a zero vector spanning one byte.
+static Value extractSliceIntoByte(ConversionPatternRewriter &rewriter,
+                                  Location loc, TypedValue<VectorType> vector,
+                                  int64_t sliceOffset, int64_t sliceNumElements,
+                                  int64_t byteOffset) {
+  unsigned elemBits = vector.getType().getElementTypeBitWidth();
+  assert(8 % elemBits == 0 && "vector element must be a valid sub-byte type");
+  int64_t scale = 8 / elemBits;
+  assert(byteOffset + sliceNumElements <= scale && "slice must fit in a byte");
+  auto byteType = VectorType::get({scale}, vector.getType().getElementType());
+  auto emptyByteVector = rewriter.create<arith::ConstantOp>(
+      loc, byteType, rewriter.getZeroAttr(byteType));
+  Value extracted = staticallyExtractSubvector(rewriter, loc, vector,
+                                               sliceOffset, sliceNumElements);
+  return staticallyInsertSubvector(rewriter, loc, extracted, emptyByteVector,
+                                   byteOffset);
+}
+
 namespace {
 
 //===----------------------------------------------------------------------===//
@@ -256,7 +312,8 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
 
     auto loc = op.getLoc();
     auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
-    Type oldElementType = op.getValueToStore().getType().getElementType();
+    auto valueToStore = op.getValueToStore();
+    Type oldElementType = valueToStore.getType().getElementType();
     Type newElementType = convertedType.getElementType();
     int srcBits = oldElementType.getIntOrFloatBitWidth();
     int dstBits = newElementType.getIntOrFloatBitWidth();
@@ -280,15 +337,15 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
     // vector.store %bitcast, %alloc[%linear_index] : memref<16xi8>,
     // vector<4xi8>
 
-    auto origElements = op.getValueToStore().getType().getNumElements();
-    if (origElements % scale != 0)
-      return failure();
+    auto origElements = valueToStore.getType().getNumElements();
+    bool isUnalignedEmulation = origElements % scale != 0;
 
     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
 
     OpFoldResult linearizedIndices;
-    std::tie(std::ignore, linearizedIndices) =
+    memref::LinearizedMemRefInfo linearizedInfo;
+    std::tie(linearizedInfo, linearizedIndices) =
         memref::getLinearizedMemRefOffsetAndSize(
             rewriter, loc, srcBits, dstBits,
             stridedMetadata.getConstifiedMixedOffset(),
@@ -296,14 +353,105 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
             stridedMetadata.getConstifiedMixedStrides(),
             getAsOpFoldResult(adaptor.getIndices()));
 
-    auto numElements = origElements / scale;
-    auto bitCast = rewriter.create<vector::BitCastOp>(
-        loc, VectorType::get(numElements, newElementType),
-        op.getValueToStore());
+    auto foldedIntraVectorOffset =
+        isUnalignedEmulation
+            ? getConstantIntValue(linearizedInfo.intraDataOffset)
+            : 0;
+
+    if (!foldedIntraVectorOffset) {
+      // unimplemented case for dynamic front padding size
+      return failure();
+    }
+
+    if (!isUnalignedEmulation) {
+      auto numElements = origElements / scale;
+      auto bitCast = rewriter.create<vector::BitCastOp>(
+          loc, VectorType::get(numElements, newElementType),
+          op.getValueToStore());
+      rewriter.replaceOpWithNewOp<vector::StoreOp>(
+          op, bitCast.getResult(), adaptor.getBase(),
+          getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
+      return llvm::success();
+    }
+
+    Value emulatedMemref = adaptor.getBase();
+    // the index into the target memref we are storing to
+    Value currentDestIndex =
+        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices);
+    auto constantOne = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+    auto atomicMaskType = VectorType::get({scale}, rewriter.getI1Type());
+    // the index into the source vector we are currently processing
+    auto currentSourceIndex = 0;
+
+    // 1. atomic store for the first byte
+    auto frontAtomicStoreElem = (scale - *foldedIntraVectorOffset) % scale;
+    if (frontAtomicStoreElem != 0) {
+      auto frontMaskValues = llvm::SmallVector<bool>(scale, false);
+      if (*foldedIntraVectorOffset + origElements < scale) {
+        std::fill_n(frontMaskValues.begin() + *foldedIntraVectorOffset,
+                    origElements, true);
+        frontAtomicStoreElem = origElements;
+      } else {
+        std::fill_n(frontMaskValues.end() - frontAtomicStoreElem,
+                    *foldedIntraVectorOffset, true);
+      }
+      auto frontMask = rewriter.create<arith::ConstantOp>(
+          loc, DenseElementsAttr::get(atomicMaskType, frontMaskValues));
+
+      currentSourceIndex = scale - (*foldedIntraVectorOffset);
+      auto value = extractSliceIntoByte(
+          rewriter, loc, cast<TypedValue<VectorType>>(valueToStore), 0,
+          frontAtomicStoreElem, *foldedIntraVectorOffset);
+
+      atomicStore(rewriter, loc, emulatedMemref, currentDestIndex,
+                  cast<TypedValue<VectorType>>(value), frontMask.getResult(),
+                  scale);
+
+      currentDestIndex = rewriter.create<arith::AddIOp>(
+          loc, rewriter.getIndexType(), currentDestIndex, constantOne);
+    }
+
+    if (currentSourceIndex >= origElements) {
+      rewriter.eraseOp(op);
+      return success();
+    }
+
+    // 2. non-atomic store
+    int64_t nonAtomicStoreSize = (origElements - currentSourceIndex) / scale;
+    int64_t numNonAtomicElements = nonAtomicStoreSize * scale;
+    if (nonAtomicStoreSize != 0) {
+      auto nonAtomicStorePart = staticallyExtractSubvector(
+          rewriter, loc, cast<TypedValue<VectorType>>(valueToStore),
+          currentSourceIndex, numNonAtomicElements);
+
+      nonAtomicStore(rewriter, loc, emulatedMemref, currentDestIndex,
+                     nonAtomicStorePart);
+
+      currentSourceIndex += numNonAtomicElements;
+      currentDestIndex = rewriter.create<arith::AddIOp>(
+          loc, rewriter.getIndexType(), currentDestIndex,
+          rewriter.create<arith::ConstantIndexOp>(loc, nonAtomicStoreSize));
+    }
+
+    // 3. atomic store for the last byte
+    auto remainingElements = origElements - currentSourceIndex;
+    if (remainingElements != 0) {
+      auto atomicStorePart = extractSliceIntoByte(
+          rewriter, loc, cast<TypedValue<VectorType>>(valueToStore),
+          currentSourceIndex, remainingElements, 0);
+
+      // back mask
+      auto maskValues = llvm::SmallVector<bool>(scale, 0);
+      std::fill_n(maskValues.begin(), remainingElements, 1);
+      auto backMask = rewriter.create<arith::ConstantOp>(
+          loc, DenseElementsAttr::get(atomicMaskType, maskValues));
+
+      atomicStore(rewriter, loc, emulatedMemref, currentDestIndex,
+                  cast<TypedValue<VectorType>>(atomicStorePart),
+                  backMask.getResult(), scale);
+    }
 
-    rewriter.replaceOpWithNewOp<vector::StoreOp>(
-        op, bitCast.getResult(), adaptor.getBase(),
-        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
+    rewriter.eraseOp(op);
     return success();
   }
 };
@@ -511,9 +659,8 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
           rewriter, loc, dyn_cast<TypedValue<VectorType>>(result), resultVector,
           linearizedInfo.intraDataOffset, origElements);
     } else if (isUnalignedEmulation) {
-      result =
-          staticallyExtractSubvector(rewriter, loc, op.getType(), result,
-                                     *foldedIntraVectorOffset, origElements);
+      result = staticallyExtractSubvector(
+          rewriter, loc, result, *foldedIntraVectorOffset, origElements);
     }
     rewriter.replaceOp(op, result);
     return success();
@@ -672,9 +819,8 @@ struct ConvertVectorMaskedLoad final
           rewriter, loc, dyn_cast<TypedValue<VectorType>>(result),
           op.getPassThru(), linearizedInfo.intraDataOffset, origElements);
     } else if (isUnalignedEmulation) {
-      result =
-          staticallyExtractSubvector(rewriter, loc, op.getType(), result,
-                                     *foldedIntraVectorOffset, origElements);
+      result = staticallyExtractSubvector(
+          rewriter, loc, result, *foldedIntraVectorOffset, origElements);
     }
     rewriter.replaceOp(op, result);
 
@@ -757,9 +903,8 @@ struct ConvertVectorTransferRead final
                                            linearizedInfo.intraDataOffset,
                                            origElements);
     } else if (isUnalignedEmulation) {
-      result =
-          staticallyExtractSubvector(rewriter, loc, op.getType(), result,
-                                     *foldedIntraVectorOffset, origElements);
+      result = staticallyExtractSubvector(
+          rewriter, loc, result, *foldedIntraVectorOffset, origElements);
     }
     rewriter.replaceOp(op, result);
 
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
index 7ed75ff7f1579c..3b1f5b2e160fe0 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
@@ -249,3 +249,138 @@ func.func @vector_maskedload_i2_dynamic_indexing_mixed(%passthru: vector<3xi2>,
 // CHECK: %[[IN8:.+]] = vector.insert %[[EX8]], %[[IN7]] [1] : i2 into vector<3xi2>
 // CHECK: %[[EX9:.+]] = vector.extract %[[SELECT]][%[[INCIDX2]]] : i2 from vector<8xi2>
 // CHECK: %[[IN9:.+]] = vector.insert %[[EX9]], %[[IN8]] [2] : i2 into vector<3xi2>
+
+// -----
+
+func.func @vector_store_i2_const(%arg0: vector<3xi2>) {
+    %0 = memref.alloc() : memref<3x3xi2>
+    %c0 = arith.constant 0 : index
+    %c2 = arith.constant 2 : index
+    vector.store %arg0, %0[%c2, %c0] : memref<3x3xi2>, vector<3xi2>
+    return
+}
+
+// In this example, emit 2 atomic stores: the first stores 2 elements (lanes 2-3 of byte 1), the second stores 1 element (lane 0 of byte 2).
+// CHECK: func @vector_store_i2_const(
+// CHECK-SAME: %[[ARG0:.+]]: vector<3xi2>)
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+// atomic store of the first byte
+// CHECK: %[[CST:.+]] = arith.constant dense<[false, false, true, true]> : vector<4xi1>
+// CHECK: %[[CST0:.+]] = arith.constant dense<0> : vector<4xi2>
+// CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [0], sizes = [2], strides = [1]} : vector<3xi2> to vector<2xi2>
+// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[EXTRACT]], %[[CST0]]
+// CHECK-SAME: {offsets = [2], strides = [1]} : vector<2xi2> into vector<4xi2>
+// CHECK: %[[ATOMIC_RMW:.+]] = memref.generic_atomic_rmw %[[ALLOC]][%[[C1]]] : memref<3xi8> {
+// CHECK: %[[ARG:.+]]: i8):
+// CHECK: %[[FROM_ELEM:.+]] = vector.from_elements %[[ARG]] : vector<1xi8>
+// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[FROM_ELEM]] : vector<1xi8> to vector<4xi2>
+// CHECK: %[[SELECT:.+]] = arith.select %[[CST]], %[[INSERT]], %[[BITCAST]] : vector<4xi1>, vector<4xi2>
+// CHECK: %[[BITCAST2:.+]] = vector.bitcast %[[SELECT]] : vector<4xi2> to vector<1xi8>
+// CHECK: %[[EXTRACT2:.+]] = vector.extract %[[BITCAST2]][0] : i8 from vector<1xi8>
+// CHECK: memref.atomic_yield %[[EXTRACT2]] : i8
+
+// atomic store of the second byte
+// CHECK: %[[ADDI:.+]] = arith.addi %[[C1]], %[[C1]] : index
+// CHECK: %[[SLICE2:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [2], sizes = [1], strides = [1]} : vector<3xi2> to vector<1xi2>
+// CHECK: %[[INSERT2:.+]] = vector.insert_strided_slice %[[SLICE2]], %[[CST0]]
+// CHECK-SAME: {offsets = [0], strides = [1]} : vector<1xi2> into vector<4xi2>
+// CHECK: %[[CST1:.+]] = arith.constant dense<[true, false, false, false]> : vector<4xi1>
+// CHECK: %[[ATOMIC_RMW2:.+]] = memref.generic_atomic_rmw %[[ALLOC]][%[[ADDI]]] : memref<3xi8> {
+// CHECK: %[[ARG2:.+]]: i8):
+// CHECK: %[[FROM_ELEM2:.+]] = vector.from_elements %[[ARG2]] : vector<1xi8>
+// CHECK: %[[BITCAST3:.+]] = vector.bitcast %[[FROM_ELEM2]] : vector<1xi8> to vector<4xi2>
+// CHECK: %[[SELECT2:.+]] = arith.select %[[CST1]], %[[INSERT2]], %[[BITCAST3]] : vector<4xi1>, vector<4xi2>
+// CHECK: %[[BITCAST4:.+]] = vector.bitcast %[[SELECT2]] : vector<4xi2> to vector<1xi8>
+// CHECK: %[[EXTRACT3:.+]] = vector.extract %[[BITCAST4]][0] : i8 from vector<1xi8>
+// CHECK: memref.atomic_yield %[[EXTRACT3]] : i8
+
+// -----
+
+func.func @vector_store_i2_two_partial_one_full_stores(%arg0: vector<7xi2>) {
+    %0 = memref.alloc() : memref<3x7xi2>
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    vector.store %arg0, %0[%c1, %c0] : memref<3x7xi2>, vector<7xi2>
+    return
+}
+
+// In this example, emit 2 atomic stores (1 element into byte 1, 2 elements into byte 3) around 1 non-atomic store of byte 2.
+// CHECK: func @vector_store_i2_two_partial_one_full_stores(
+// CHECK-SAME: %[[ARG0:.+]]: vector<7xi2>)
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<6xi8>
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[CST:.+]] = arith.constant dense<[false, false, false, true]> : vector<4xi1>
+// CHECK: %[[CST0:.+]] = arith.constant dense<0> : vector<4xi2>
+
+// first atomic store
+// CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [0], sizes = [1], strides = [1]} : vector<7xi2> to vector<1xi2>
+// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[EXTRACT]], %[[CST0]]
+// CHECK-SAME: {offsets = [3], strides = [1]} : vector<1xi2> into vector<4xi2>
+// CHECK: %[[ATOMIC_RMW:.+]] = memref.generic_atomic_rmw %[[ALLOC]][%[[C1]]] : memref<6xi8> {
+// CHECK: %[[ARG:.+]]: i8):
+// CHECK: %[[FROM_ELEM:.+]] = vector.from_elements %[[ARG]] : vector<1xi8>
+// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[FROM_ELEM]] : vector<1xi8> to vector<4xi2>
+// CHECK: %[[SELECT:.+]] = arith.select %[[CST]], %[[INSERT]], %[[BITCAST]] : vector<4xi1>, vector<4xi2>
+// CHECK: %[[BITCAST2:.+]] = vector.bitcast %[[SELECT]] : vector<4xi2> to vector<1xi8>
+// CHECK: %[[EXTRACT2:.+]] = vector.extract %[[BITCAST2]][0] : i8 from vector<1xi8>
+// CHECK: memref.atomic_yield %[[EXTRACT2]] : i8
+
+// non atomic store part
+// CHECK: %[[ADDR:.+]] = arith.addi %[[C1]], %[[C1]] : index
+// CHECK: %[[EXTRACT2:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [1], sizes = [4], strides = [1]} : vector<7xi2> to vector<4xi2>
+// CHECK: %[[BITCAST3:.+]] = vector.bitcast %[[EXTRACT2]] : vector<4xi2> to vector<1xi8>
+// CHECK: vector.store %[[BITCAST3]], %[[ALLOC]][%[[ADDR]]] : memref<6xi8>, vector<1xi8>
+
+// second atomic store
+// CHECK: %[[ADDR2:.+]] = arith.addi %[[ADDR]], %[[C1]] : index
+// CHECK: %[[EXTRACT3:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [5], sizes = [2], strides = [1]} : vector<7xi2> to vector<2xi2>
+// CHECK: %[[INSERT2:.+]] = vector.insert_strided_slice %[[EXTRACT3]], %[[CST0]]
+// CHECK-SAME: {offsets = [0], strides = [1]} : vector<2xi2> into vector<4xi2>
+// CHECK: %[[CST1:.+]] = arith.constant dense<[true, true, false, false]> : vector<4xi1>
+// CHECK: %[[ATOMIC_RMW2:.+]] = memref.generic_atomic_rmw %[[ALLOC]][%[[ADDR2]]] : memref<6xi8> {
+// CHECK: %[[ARG2:.+]]: i8):
+// CHECK: %[[FROM_ELEM2:.+]] = vector.from_elements %[[ARG2]] : vector<1xi8>
+// CHECK: %[[BITCAST4:.+]] = vector.bitcast %[[FROM_ELEM2]] : vector<1xi8> to vector<4xi2>
+// CHECK: %[[SELECT2:.+]] = arith.select %[[CST1]], %[[INSERT2]], %[[BITCAST4]] :
+// CHECK-SAME: vector<4xi1>, vector<4xi2>
+// CHECK: %[[BITCAST5:.+]] = vector.bitcast %[[SELECT2]] : vector<4xi2> to vector<1xi8>
+// CHECK: %[[EXTRACT4:.+]] = vector.extract %[[BITCAST5]][0] : i8 from vector<1xi8>
+// CHECK: memref.atomic_yield %[[EXTRACT4]] : i8
+
+// -----
+
+func.func @vector_store_i2_single_atomic(%arg0: vector<1xi2>) {
+    %0 = memref.alloc() : memref<4x1xi2>
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    vector.store %arg0, %0[%c1, %c0] :memref<4x1xi2>, vector<1xi2>
+    return
+}
+
+// In this example, the single element lands in lane 1 of byte 0, so only 1 masked atomic store is emitted.
+// CHECK: func @vector_store_i2_single_atomic(
+// CHECK-SAME: %[[ARG0:.+]]: vector<1xi2>)
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<1xi8>
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[CST:.+]] = arith.constant dense<[false, true, false, false]> : vector<4xi1>
+// CHECK: %[[CST0:.+]] = arith.constant dense<0> : vector<4xi2>
+// CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [0], sizes = [1], strides = [1]} : vector<1xi2> to vector<1xi2>
+// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[EXTRACT]], %[[CST0]]
+// CHECK-SAME: {offsets = [1], strides = [1]} : vector<1xi2> into vector<4xi2>
+// The RMW below writes lane 1 from INSERT and keeps lanes 0, 2 and 3 of the current byte (mask CST).
+// CHECK: %[[ATOMIC_RMW:.+]] = memref.generic_atomic_rmw %[[ALLOC]][%[[C0]]] : memref<1xi8> {
+// CHECK: %[[ARG:.+]]: i8):
+// CHECK: %[[FROM_ELEM:.+]] = vector.from_elements %[[ARG]] : vector<1xi8>
+// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[FROM_ELEM]] : vector<1xi8> to vector<4xi2>
+// CHECK: %[[SELECT:.+]] = arith.select %[[CST]], %[[INSERT]], %[[BITCAST]] : vector<4xi1>, vector<4xi2>
+// CHECK: %[[BITCAST2:.+]] = vector.bitcast %[[SELECT]] : vector<4xi2> to vector<1xi8>
+// CHECK: %[[EXTRACT2:.+]] = vector.extract %[[BITCAST2]][0] : i8 from vector<1xi8>
+// CHECK: memref.atomic_yield %[[EXTRACT2]] : i8



More information about the Mlir-commits mailing list