[Mlir-commits] [mlir] [mlir] Add subbyte emulation support for `memref.store`. (PR #72004)

llvmlistbot at llvm.org
Wed Nov 22 13:45:18 PST 2023


https://github.com/Max191 updated https://github.com/llvm/llvm-project/pull/72004

>From f1c15b5850998c39e1ab6438e7c9a4ec4d581137 Mon Sep 17 00:00:00 2001
From: Max Dawkins <max.dawkins at gmail.com>
Date: Fri, 10 Nov 2023 17:06:43 -0500
Subject: [PATCH] [mlir] Add subbyte emulation support for `memref.store`. This
 adds a conversion for narrow type emulation of memref.store ops. The
 conversion replaces the memref.store with two memref.atomic_rmw ops. Atomics
 are used to prevent race conditions when two threads store into different
 subbyte elements that share the same byte.
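
For a concrete picture, the i4-in-i8 case exercised by the new tests rewrites

  memref.store %arg1, %0[%arg0] : memref<5xi4>

into roughly the following (value names are illustrative; the exact expected
IR is in the tests below):

  %alloc = memref.alloc() : memref<3xi8>
  %v = arith.extui %arg1 : i4 to i8
  %byte_idx = affine.apply affine_map<()[s0] -> (s0 floordiv 2)>()[%arg0]
  %bit_off = affine.apply affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 2) * 8)>()[%arg0]
  %bit_off_i8 = arith.index_cast %bit_off : index to i8
  %c15 = arith.constant 15 : i8
  %mask_inv = arith.shli %c15, %bit_off_i8 : i8
  %cneg1 = arith.constant -1 : i8
  %mask = arith.xori %mask_inv, %cneg1 : i8
  %val = arith.shli %v, %bit_off_i8 : i8
  %old0 = memref.atomic_rmw andi %mask, %alloc[%byte_idx] : (i8, memref<3xi8>) -> i8
  %old1 = memref.atomic_rmw ori %val, %alloc[%byte_idx] : (i8, memref<3xi8>) -> i8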

---
 .../MemRef/Transforms/EmulateNarrowType.cpp   | 164 ++++++++++++++---
 .../Dialect/MemRef/emulate-narrow-type.mlir   | 169 ++++++++++++++++++
 2 files changed, 308 insertions(+), 25 deletions(-)

diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
index 9f58e9055acadbb..553af2adb60f3d6 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
@@ -17,6 +17,9 @@
 #include "mlir/Dialect/MemRef/Transforms/Transforms.h"
 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/OpDefinition.h"
 #include "mlir/Support/MathExtras.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/Support/FormatVariadic.h"
@@ -29,6 +32,26 @@ using namespace mlir;
 // Utility functions
 //===----------------------------------------------------------------------===//
 
+/// Replaces the memref::StoreOp with two new memref::AtomicRMWOps. The first
+/// memref::AtomicRMWOp clears the destination bits with an `andi` of the
+/// write mask, preparing the destination byte for the write. The second
+/// memref::AtomicRMWOp writes the value to store with an `ori`. Both the
+/// value to store and the write mask must have the destination type bitwidth,
+/// and the value to store must be zero except for the bits aligned with the
+/// store destination.
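+///
+/// For an i8-emulated buffer, for example, the two ops take the form:
+///   memref.atomic_rmw andi %mask, %dst[%idx] : (i8, memref<?xi8>) -> i8
+///   memref.atomic_rmw ori %val, %dst[%idx] : (i8, memref<?xi8>) -> i8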
+static void replaceStoreWithAtomics(ConversionPatternRewriter &rewriter,
+                                    memref::StoreOp op, Value writeMask,
+                                    Value storeVal, Value memref,
+                                    ValueRange storeIndices) {
+  // Clear destination bits
+  rewriter.create<memref::AtomicRMWOp>(op.getLoc(), arith::AtomicRMWKind::andi,
+                                       writeMask, memref, storeIndices);
+  // Write the source bits to the destination
+  rewriter.create<memref::AtomicRMWOp>(op->getLoc(), arith::AtomicRMWKind::ori,
+                                       storeVal, memref, storeIndices);
+  rewriter.eraseOp(op);
+}
+
 /// When data is loaded/stored in `targetBits` granularity, but is used in
 /// `sourceBits` granularity (`sourceBits` < `targetBits`), the `targetBits` is
 /// treated as an array of elements of width `sourceBits`.
@@ -43,13 +66,67 @@ static Value getOffsetForBitwidth(Location loc, OpFoldResult srcIdx,
   AffineExpr s0;
   bindSymbols(builder.getContext(), s0);
   int scaleFactor = targetBits / sourceBits;
-  OpFoldResult offsetVal = affine::makeComposedFoldedAffineApply(
-      builder, loc, (s0 % scaleFactor) * sourceBits, {srcIdx});
+  AffineExpr offsetExpr = (s0 % scaleFactor) * sourceBits;
+  OpFoldResult offsetVal =
+      affine::makeComposedFoldedAffineApply(builder, loc, offsetExpr, {srcIdx});
   Value bitOffset = getValueOrCreateConstantIndexOp(builder, loc, offsetVal);
   IntegerType dstType = builder.getIntegerType(targetBits);
   return builder.create<arith::IndexCastOp>(loc, dstType, bitOffset);
 }
 
+/// When writing a subbyte size, the write must happen atomically in case
+/// another thread is writing to the same byte at the same time. Before the
+/// write, the destination bits covered by the subbyte store must be cleared.
+/// This function returns the mask used to clear those bits.
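+///
+/// For example, with `srcBits` = 4, `dstBits` = 8, and a bit offset of 4, the
+/// returned mask is ~(0x0F << 4) = 0x0F, so the subsequent `andi` clears the
+/// high nibble that the shifted value is then written into.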
+static Value getAtomicWriteMask(Location loc, OpFoldResult linearizedIndices,
+                                int64_t srcBits, int64_t dstBits,
+                                Value bitwidthOffset, OpBuilder &builder) {
+  auto dstIntegerType = builder.getIntegerType(dstBits);
+  auto maskRightAlignedAttr =
+      builder.getIntegerAttr(dstIntegerType, (1 << srcBits) - 1);
+  Value maskRightAligned =
+      builder
+          .create<arith::ConstantOp>(loc, dstIntegerType, maskRightAlignedAttr)
+          .getResult();
+  Value writeMaskInverse =
+      builder.create<arith::ShLIOp>(loc, maskRightAligned, bitwidthOffset);
+  auto flipValAttr = builder.getIntegerAttr(dstIntegerType, -1);
+  Value flipVal =
+      builder.create<arith::ConstantOp>(loc, dstIntegerType, flipValAttr)
+          .getResult();
+  return builder.create<arith::XOrIOp>(loc, writeMaskInverse, flipVal);
+}
+
+/// Returns the scaled linearized index based on the `srcBits` and `dstBits`
+/// sizes. The input `linearizedIndex` has the granularity of `srcBits`, and
+/// the returned index has the granularity of `dstBits`.
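+/// For example, with `srcBits` = 4 and `dstBits` = 8, the i4 element at
+/// linearized index 5 maps to byte index 5 floordiv 2 = 2.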
+static Value getIndicesForLoadOrStore(OpBuilder &builder, Location loc,
+                                      OpFoldResult linearizedIndex,
+                                      int64_t srcBits, int64_t dstBits) {
+  AffineExpr s0;
+  bindSymbols(builder.getContext(), s0);
+  int64_t scaler = dstBits / srcBits;
+  OpFoldResult scaledLinearizedIndices = affine::makeComposedFoldedAffineApply(
+      builder, loc, s0.floorDiv(scaler), {linearizedIndex});
+  return getValueOrCreateConstantIndexOp(builder, loc, scaledLinearizedIndices);
+}
+
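+/// Linearizes the given `indices` of a memref whose elements are `srcBits`
+/// wide into a single offset, using the memref's strided metadata. The result
+/// is still expressed at `srcBits` granularity; scaling to the wider emulated
+/// element type is applied separately.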
+static OpFoldResult
+getLinearizedSrcIndices(OpBuilder &builder, Location loc, int64_t srcBits,
+                        const SmallVector<OpFoldResult> &indices,
+                        Value memref) {
+  auto stridedMetadata =
+      builder.create<memref::ExtractStridedMetadataOp>(loc, memref);
+  OpFoldResult linearizedIndices;
+  std::tie(std::ignore, linearizedIndices) =
+      memref::getLinearizedMemRefOffsetAndSize(
+          builder, loc, srcBits, srcBits,
+          stridedMetadata.getConstifiedMixedOffset(),
+          stridedMetadata.getConstifiedMixedSizes(),
+          stridedMetadata.getConstifiedMixedStrides(), indices);
+  return linearizedIndices;
+}
+
 namespace {
 
 //===----------------------------------------------------------------------===//
@@ -155,32 +232,15 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
       bitsLoad = rewriter.create<memref::LoadOp>(loc, adaptor.getMemref(),
                                                  ValueRange{});
     } else {
-      SmallVector<OpFoldResult> indices =
-          getAsOpFoldResult(adaptor.getIndices());
-
-      auto stridedMetadata = rewriter.create<memref::ExtractStridedMetadataOp>(
-          loc, op.getMemRef());
-
       // Linearize the indices of the original load instruction. Do not account
       // for the scaling yet. This will be accounted for later.
-      OpFoldResult linearizedIndices;
-      std::tie(std::ignore, linearizedIndices) =
-          memref::getLinearizedMemRefOffsetAndSize(
-              rewriter, loc, srcBits, srcBits,
-              stridedMetadata.getConstifiedMixedOffset(),
-              stridedMetadata.getConstifiedMixedSizes(),
-              stridedMetadata.getConstifiedMixedStrides(), indices);
-
-      AffineExpr s0;
-      bindSymbols(rewriter.getContext(), s0);
-      int64_t scaler = dstBits / srcBits;
-      OpFoldResult scaledLinearizedIndices =
-          affine::makeComposedFoldedAffineApply(
-              rewriter, loc, s0.floorDiv(scaler), {linearizedIndices});
+      OpFoldResult linearizedIndices = getLinearizedSrcIndices(
+          rewriter, loc, srcBits, adaptor.getIndices(), op.getMemRef());
+
       Value newLoad = rewriter.create<memref::LoadOp>(
           loc, adaptor.getMemref(),
-          getValueOrCreateConstantIndexOp(rewriter, loc,
-                                          scaledLinearizedIndices));
+          getIndicesForLoadOrStore(rewriter, loc, linearizedIndices, srcBits,
+                                   dstBits));
 
       // Get the offset and shift the bits to the rightmost.
       // Note, currently only the big-endian is supported.
@@ -211,6 +271,60 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
   }
 };
 
+//===----------------------------------------------------------------------===//
+// ConvertMemRefStore
+//===----------------------------------------------------------------------===//
+
+struct ConvertMemRefStore final : OpConversionPattern<memref::StoreOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto convertedType = adaptor.getMemref().getType().cast<MemRefType>();
+    int srcBits = op.getMemRefType().getElementTypeBitWidth();
+    int dstBits = convertedType.getElementTypeBitWidth();
+    auto dstIntegerType = rewriter.getIntegerType(dstBits);
+    if (dstBits % srcBits != 0) {
+      return rewriter.notifyMatchFailure(
+          op, "only dstBits % srcBits == 0 supported");
+    }
+
+    Location loc = op.getLoc();
+    Value extendedInput = rewriter.create<arith::ExtUIOp>(loc, dstIntegerType,
+                                                          adaptor.getValue());
+
+    // Special case 0-rank memref stores. We can compute the mask at compile
+    // time.
+    if (convertedType.getRank() == 0) {
+      // Create mask to clear destination bits
+      auto writeMaskValAttr =
+          rewriter.getIntegerAttr(dstIntegerType, ~((1 << srcBits) - 1));
+      Value writeMask = rewriter.create<arith::ConstantOp>(loc, dstIntegerType,
+                                                           writeMaskValAttr);
+
+      replaceStoreWithAtomics(rewriter, op, writeMask, extendedInput,
+                              adaptor.getMemref(), ValueRange{});
+      return success();
+    }
+
+    OpFoldResult linearizedIndices = getLinearizedSrcIndices(
+        rewriter, loc, srcBits, adaptor.getIndices(), op.getMemRef());
+    Value storeIndices = getIndicesForLoadOrStore(
+        rewriter, loc, linearizedIndices, srcBits, dstBits);
+    Value bitwidthOffset = getOffsetForBitwidth(loc, linearizedIndices, srcBits,
+                                                dstBits, rewriter);
+    Value writeMask = getAtomicWriteMask(loc, linearizedIndices, srcBits,
+                                         dstBits, bitwidthOffset, rewriter);
+    // Align the value to write with the destination bits
+    Value alignedVal =
+        rewriter.create<arith::ShLIOp>(loc, extendedInput, bitwidthOffset);
+    replaceStoreWithAtomics(rewriter, op, writeMask, alignedVal,
+                            adaptor.getMemref(), storeIndices);
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // ConvertMemRefSubview
 //===----------------------------------------------------------------------===//
@@ -292,7 +406,7 @@ void memref::populateMemRefNarrowTypeEmulationPatterns(
 
   // Populate `memref.*` conversion patterns.
   patterns.add<ConvertMemRefAlloc, ConvertMemRefLoad,
-               ConvertMemRefAssumeAlignment, ConvertMemRefSubview>(
+               ConvertMemRefAssumeAlignment, ConvertMemRefSubview,
+               ConvertMemRefStore>(
       typeConverter, patterns.getContext());
   memref::populateResolveExtractStridedMetadataPatterns(patterns);
 }
diff --git a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
index 6ed97f05aa7cff2..22c5947fd2ac97b 100644
--- a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
+++ b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
@@ -174,3 +174,172 @@ func.func @memref_strided_i4(%idx : index) -> i4 {
 //       CHECK32:   %[[ALLOC:.+]] = memref.alloc() : memref<16xi32>
 //       CHECK32:   %[[SUBVIEW:.+]] = memref.subview %[[ALLOC]][4] [4] [1] : memref<16xi32> to memref<4xi32, strided<[1], offset: 4>>
 //       CHECK32:   %[[LOAD:.+]] = memref.load %[[SUBVIEW]]
+
+// -----
+
+func.func @memref_store_i4(%arg0: index, %arg1: i4) -> () {
+    %0 = memref.alloc() : memref<5xi4>
+    memref.store %arg1, %0[%arg0] : memref<5xi4>
+    return
+}
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 floordiv 2)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 2) * 8)>
+//      CHECK: func @memref_store_i4(
+// CHECK-SAME:     %[[ARG0:.+]]: index, %[[ARG1:.+]]: i4
+//  CHECK-DAG:   %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
+//  CHECK-DAG:   %[[EXTUI:.+]] = arith.extui %[[ARG1]] : i4 to i8
+//  CHECK-DAG:   %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
+//  CHECK-DAG:   %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]]]
+//  CHECK-DAG:   %[[BITOFFSET_I8:.+]] = arith.index_cast %[[BITOFFSET]] : index to i8
+//  CHECK-DAG:   %[[MASK_BASE:.+]] = arith.constant 15 : i8
+//  CHECK-DAG:   %[[MASK_SHIFTED:.+]] = arith.shli %[[MASK_BASE]], %[[BITOFFSET_I8]] : i8
+//  CHECK-DAG:   %[[CST_NEG_ONE:.+]] = arith.constant -1 : i8
+//  CHECK-DAG:   %[[MASK:.+]] = arith.xori %[[MASK_SHIFTED]], %[[CST_NEG_ONE]] : i8
+//  CHECK-DAG:   %[[SHIFTED_VAL:.+]] = arith.shli %[[EXTUI]], %[[BITOFFSET_I8]] : i8
+//      CHECK:   %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %[[ALLOC]][%[[INDEX]]] : (i8, memref<3xi8>) -> i8
+//      CHECK:   %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[SHIFTED_VAL]], %[[ALLOC]][%[[INDEX]]] : (i8, memref<3xi8>) -> i8
+//      CHECK:   return
+
+//  CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 floordiv 8)>
+//  CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 8) * 32)>
+//      CHECK32: func @memref_store_i4(
+// CHECK32-SAME:     %[[ARG0:.+]]: index, %[[ARG1:.+]]: i4
+//  CHECK32-DAG:   %[[ALLOC:.+]] = memref.alloc() : memref<1xi32>
+//  CHECK32-DAG:   %[[EXTUI:.+]] = arith.extui %[[ARG1]] : i4 to i32
+//  CHECK32-DAG:   %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
+//  CHECK32-DAG:   %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]]]
+//  CHECK32-DAG:   %[[BITOFFSET_I32:.+]] = arith.index_cast %[[BITOFFSET]] : index to i32
+//  CHECK32-DAG:   %[[MASK_BASE:.+]] = arith.constant 15 : i32
+//  CHECK32-DAG:   %[[MASK_SHIFTED:.+]] = arith.shli %[[MASK_BASE]], %[[BITOFFSET_I32]] : i32
+//  CHECK32-DAG:   %[[CST_NEG_ONE:.+]] = arith.constant -1 : i32
+//  CHECK32-DAG:   %[[MASK:.+]] = arith.xori %[[MASK_SHIFTED]], %[[CST_NEG_ONE]] : i32
+//  CHECK32-DAG:   %[[SHIFTED_VAL:.+]] = arith.shli %[[EXTUI]], %[[BITOFFSET_I32]] : i32
+//      CHECK32:   %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %[[ALLOC]][%[[INDEX]]] : (i32, memref<1xi32>) -> i32
+//      CHECK32:   %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[SHIFTED_VAL]], %[[ALLOC]][%[[INDEX]]] : (i32, memref<1xi32>) -> i32
+//      CHECK32:   return
+
+// -----
+
+func.func @memref_store_i4_rank2(%arg0: index, %arg1: index, %arg2: i4) -> () {
+    %0 = memref.alloc() : memref<3x125xi4>
+    memref.assume_alignment %0, 64 : memref<3x125xi4>
+    memref.store %arg2, %0[%arg0,%arg1] : memref<3x125xi4>
+    return
+}
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * 125 + s1) floordiv 2)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 * 500 + s1 * 4 - ((s0 * 125 + s1) floordiv 2) * 8)>
+//      CHECK: func @memref_store_i4_rank2(
+// CHECK-SAME:     %[[ARG0:.+]]: index, %[[ARG1:.+]]: index, %[[ARG2:.+]]: i4
+//  CHECK-DAG:   %[[ALLOC:.+]] = memref.alloc() : memref<188xi8>
+//  CHECK-DAG:   memref.assume_alignment %[[ALLOC]], 64 : memref<188xi8>
+//  CHECK-DAG:   %[[EXTUI:.+]] = arith.extui %[[ARG2]] : i4 to i8
+//  CHECK-DAG:   %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
+//  CHECK-DAG:   %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]], %[[ARG1]]]
+//  CHECK-DAG:   %[[BITOFFSET_I8:.+]] = arith.index_cast %[[BITOFFSET]] : index to i8
+//  CHECK-DAG:   %[[MASK_BASE:.+]] = arith.constant 15 : i8
+//  CHECK-DAG:   %[[MASK_SHIFTED:.+]] = arith.shli %[[MASK_BASE]], %[[BITOFFSET_I8]] : i8
+//  CHECK-DAG:   %[[CST_NEG_ONE:.+]] = arith.constant -1 : i8
+//  CHECK-DAG:   %[[MASK:.+]] = arith.xori %[[MASK_SHIFTED]], %[[CST_NEG_ONE]] : i8
+//  CHECK-DAG:   %[[SHIFTED_VAL:.+]] = arith.shli %[[EXTUI]], %[[BITOFFSET_I8]] : i8
+//      CHECK:   %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %[[ALLOC]][%[[INDEX]]] : (i8, memref<188xi8>) -> i8
+//      CHECK:   %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[SHIFTED_VAL]], %[[ALLOC]][%[[INDEX]]] : (i8, memref<188xi8>) -> i8
+//      CHECK:   return
+
+//  CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * 125 + s1) floordiv 8)>
+//  CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 * 500 + s1 * 4 - ((s0 * 125 + s1) floordiv 8) * 32)>
+//      CHECK32: func @memref_store_i4_rank2(
+// CHECK32-SAME:     %[[ARG0:.+]]: index, %[[ARG1:.+]]: index, %[[ARG2:.+]]: i4
+//  CHECK32-DAG:   %[[ALLOC:.+]] = memref.alloc() : memref<47xi32>
+//  CHECK32-DAG:   memref.assume_alignment %[[ALLOC]], 64 : memref<47xi32>
+//  CHECK32-DAG:   %[[EXTUI:.+]] = arith.extui %[[ARG2]] : i4 to i32
+//  CHECK32-DAG:   %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
+//  CHECK32-DAG:   %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]], %[[ARG1]]]
+//  CHECK32-DAG:   %[[BITOFFSET_I32:.+]] = arith.index_cast %[[BITOFFSET]] : index to i32
+//  CHECK32-DAG:   %[[MASK_BASE:.+]] = arith.constant 15 : i32
+//  CHECK32-DAG:   %[[MASK_SHIFTED:.+]] = arith.shli %[[MASK_BASE]], %[[BITOFFSET_I32]] : i32
+//  CHECK32-DAG:   %[[CST_NEG_ONE:.+]] = arith.constant -1 : i32
+//  CHECK32-DAG:   %[[MASK:.+]] = arith.xori %[[MASK_SHIFTED]], %[[CST_NEG_ONE]] : i32
+//  CHECK32-DAG:   %[[SHIFTED_VAL:.+]] = arith.shli %[[EXTUI]], %[[BITOFFSET_I32]] : i32
+//      CHECK32:   %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %[[ALLOC]][%[[INDEX]]] : (i32, memref<47xi32>) -> i32
+//      CHECK32:   %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[SHIFTED_VAL]], %[[ALLOC]][%[[INDEX]]] : (i32, memref<47xi32>) -> i32
+//      CHECK32:   return
+
+// -----
+
+func.func @memref_store_i4_dynamic(%arg0: index, %arg1 : index, %arg2 : index, %arg3 : index, %arg4: i4) -> () {
+  %0 = memref.alloc(%arg0, %arg1) : memref<?x?xi4>
+  memref.store %arg4, %0[%arg2, %arg3] : memref<?x?xi4>
+  return
+}
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 2)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 + s0 * s1) floordiv 2)>
+//  CHECK-DAG: #[[MAP2:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) * 4 + s2 * 4 - ((s2 + s0 * s1) floordiv 2) * 8)>
+//      CHECK: func @memref_store_i4_dynamic(
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index
+// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: index
+// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: index
+// CHECK-SAME:     %[[ARG3:[a-zA-Z0-9]+]]: index
+// CHECK-SAME:     %[[ARG4:[a-zA-Z0-9]+]]: i4
+//  CHECK-DAG:   %[[SIZE:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
+//  CHECK-DAG:   %[[ALLOC:.+]] = memref.alloc(%[[SIZE]]) : memref<?xi8>
+//  CHECK-DAG:   %[[EXTUI:.+]] = arith.extui %[[ARG4]] : i4 to i8
+//  CHECK-DAG:   %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]]
+//  CHECK-DAG:   %[[BITOFFSET:.+]] = affine.apply #[[MAP2]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]]
+//  CHECK-DAG:   %[[BITOFFSET_I8:.+]] = arith.index_cast %[[BITOFFSET]] : index to i8
+//  CHECK-DAG:   %[[MASK_BASE:.+]] = arith.constant 15 : i8
+//  CHECK-DAG:   %[[MASK_SHIFTED:.+]] = arith.shli %[[MASK_BASE]], %[[BITOFFSET_I8]] : i8
+//  CHECK-DAG:   %[[CST_NEG_ONE:.+]] = arith.constant -1 : i8
+//  CHECK-DAG:   %[[MASK:.+]] = arith.xori %[[MASK_SHIFTED]], %[[CST_NEG_ONE]] : i8
+//  CHECK-DAG:   %[[SHIFTED_VAL:.+]] = arith.shli %[[EXTUI]], %[[BITOFFSET_I8]] : i8
+//      CHECK:   %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %[[ALLOC]][%[[INDEX]]] : (i8, memref<?xi8>) -> i8
+//      CHECK:   %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[SHIFTED_VAL]], %[[ALLOC]][%[[INDEX]]] : (i8, memref<?xi8>) -> i8
+//      CHECK:   return
+
+//  CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 8)>
+//  CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 + s0 * s1) floordiv 8)>
+//  CHECK32-DAG: #[[MAP2:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) * 4 + s2 * 4 - ((s2 + s0 * s1) floordiv 8) * 32)>
+//      CHECK32: func @memref_store_i4_dynamic(
+// CHECK32-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index
+// CHECK32-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: index
+// CHECK32-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: index
+// CHECK32-SAME:     %[[ARG3:[a-zA-Z0-9]+]]: index
+// CHECK32-SAME:     %[[ARG4:[a-zA-Z0-9]+]]: i4
+//  CHECK32-DAG:   %[[SIZE:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
+//  CHECK32-DAG:   %[[ALLOC:.+]] = memref.alloc(%[[SIZE]]) : memref<?xi32>
+//  CHECK32-DAG:   %[[EXTUI:.+]] = arith.extui %[[ARG4]] : i4 to i32
+//  CHECK32-DAG:   %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]]
+//  CHECK32-DAG:   %[[BITOFFSET:.+]] = affine.apply #[[MAP2]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]]
+//  CHECK32-DAG:   %[[BITOFFSET_I32:.+]] = arith.index_cast %[[BITOFFSET]] : index to i32
+//  CHECK32-DAG:   %[[MASK_BASE:.+]] = arith.constant 15 : i32
+//  CHECK32-DAG:   %[[MASK_SHIFTED:.+]] = arith.shli %[[MASK_BASE]], %[[BITOFFSET_I32]] : i32
+//  CHECK32-DAG:   %[[CST_NEG_ONE:.+]] = arith.constant -1 : i32
+//  CHECK32-DAG:   %[[MASK:.+]] = arith.xori %[[MASK_SHIFTED]], %[[CST_NEG_ONE]] : i32
+//  CHECK32-DAG:   %[[SHIFTED_VAL:.+]] = arith.shli %[[EXTUI]], %[[BITOFFSET_I32]] : i32
+//      CHECK32:   %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %[[ALLOC]][%[[INDEX]]] : (i32, memref<?xi32>) -> i32
+//      CHECK32:   %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[SHIFTED_VAL]], %[[ALLOC]][%[[INDEX]]] : (i32, memref<?xi32>) -> i32
+//      CHECK32:   return
+
+// -----
+
+func.func @rank_zero_memref_store(%arg0: i4) -> () {
+  %0 = memref.alloc() : memref<i4>
+  memref.store %arg0, %0[] : memref<i4>
+  return
+}
+// CHECK-LABEL: func @rank_zero_memref
+//  CHECK-SAME:     %[[ARG0:.+]]: i4
+//       CHECK:   %[[ALLOC:.+]] = memref.alloc() : memref<i8>
+//       CHECK:   %[[EXTUI:.+]] = arith.extui %[[ARG0]] : i4 to i8
+//       CHECK:   %[[MASK:.+]] = arith.constant -16 : i8
+//       CHECK:   %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %[[ALLOC]][] : (i8, memref<i8>) -> i8
+//       CHECK:   %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[EXTUI]], %[[ALLOC]][] : (i8, memref<i8>) -> i8
+//       CHECK:   return
+
+// CHECK32-LABEL: func @rank_zero_memref
+//  CHECK32-SAME:     %[[ARG0:.+]]: i4
+//       CHECK32:   %[[ALLOC:.+]] = memref.alloc() : memref<i32>
+//       CHECK32:   %[[EXTUI:.+]] = arith.extui %[[ARG0]] : i4 to i32
+//       CHECK32:   %[[MASK:.+]] = arith.constant -16 : i32
+//       CHECK32:   %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %[[ALLOC]][] : (i32, memref<i32>) -> i32
+//       CHECK32:   %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[EXTUI]], %[[ALLOC]][] : (i32, memref<i32>) -> i32
+//       CHECK32:   return
