[Mlir-commits] [mlir] [mlir] Add subbyte emulation support for `memref.store`. (PR #73174)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Nov 28 07:11:53 PST 2023


https://github.com/Max191 updated https://github.com/llvm/llvm-project/pull/73174

From 229aa00408d79f0d1700287373574a6a6aa07bad Mon Sep 17 00:00:00 2001
From: Max Dawkins <max.dawkins at gmail.com>
Date: Fri, 10 Nov 2023 17:06:43 -0500
Subject: [PATCH] [mlir] Add subbyte emulation support for `memref.store`. This
 adds a conversion for narrow-type emulation of memref.store ops. The
 conversion replaces each memref.store with two memref.atomic_rmw ops. Atomics
 are used to avoid race conditions when two threads store to different subbyte
 elements that are packed into the same byte.
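
The lowering computes the index of the backing byte/word, the bit offset of
the subbyte element within it, and a mask that clears just those bits (e.g.
0x0F when writing the high nibble of an i8). It then issues an atomic `andi`
to clear the destination bits followed by an atomic `ori` to write the new
bits. As a sketch (paraphrased from the tests added below; SSA names are
illustrative), storing an i4 into a memref<5xi4> emulated with i8 storage

  memref.store %val, %alloc[%idx] : memref<5xi4>

becomes roughly:

  #byte_map = affine_map<()[s0] -> (s0 floordiv 2)>
  #bit_map  = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 2) * 8)>
  %alloc_i8   = memref.alloc() : memref<3xi8>
  %ext        = arith.extui %val : i4 to i8
  %byte_idx   = affine.apply #byte_map()[%idx]
  %bit_off    = affine.apply #bit_map()[%idx]
  %bit_off_i8 = arith.index_cast %bit_off : index to i8
  %c15        = arith.constant 15 : i8
  %mask_inv   = arith.shli %c15, %bit_off_i8 : i8
  %cneg1      = arith.constant -1 : i8
  %mask       = arith.xori %mask_inv, %cneg1 : i8
  %val_shl    = arith.shli %ext, %bit_off_i8 : i8
  %old        = memref.atomic_rmw andi %mask, %alloc_i8[%byte_idx] : (i8, memref<3xi8>) -> i8
  %new        = memref.atomic_rmw ori %val_shl, %alloc_i8[%byte_idx] : (i8, memref<3xi8>) -> i8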

---
 .../MemRef/Transforms/EmulateNarrowType.cpp   | 158 +++++++++++++---
 .../Dialect/MemRef/emulate-narrow-type.mlir   | 169 ++++++++++++++++++
 2 files changed, 300 insertions(+), 27 deletions(-)

diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
index e5801c3733ed5a8..ec94ecdcf62502b 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
@@ -17,7 +17,9 @@
 #include "mlir/Dialect/MemRef/Transforms/Transforms.h"
 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/OpDefinition.h"
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Support/MathExtras.h"
 #include "mlir/Transforms/DialectConversion.h"
@@ -102,13 +104,64 @@ static Value getOffsetForBitwidth(Location loc, OpFoldResult srcIdx,
   AffineExpr s0;
   bindSymbols(builder.getContext(), s0);
   int scaleFactor = targetBits / sourceBits;
-  OpFoldResult offsetVal = affine::makeComposedFoldedAffineApply(
-      builder, loc, (s0 % scaleFactor) * sourceBits, {srcIdx});
+  AffineExpr offsetExpr = (s0 % scaleFactor) * sourceBits;
+  OpFoldResult offsetVal =
+      affine::makeComposedFoldedAffineApply(builder, loc, offsetExpr, {srcIdx});
   Value bitOffset = getValueOrCreateConstantIndexOp(builder, loc, offsetVal);
   IntegerType dstType = builder.getIntegerType(targetBits);
   return builder.create<arith::IndexCastOp>(loc, dstType, bitOffset);
 }
 
+/// When writing a subbyte size, masked bitwise operations are used so that
+/// only the relevant bits are modified. This function returns the AND mask
+/// for clearing the destination bits of a subbyte write. E.g., when writing
+/// to the second i4 in an i32, 0xFFFFFF0F is created.
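+/// Likewise, writing to the second i4 in an i8 yields the mask 0x0F.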
+static Value getSubByteWriteMask(Location loc, OpFoldResult linearizedIndices,
+                                 int64_t srcBits, int64_t dstBits,
+                                 Value bitwidthOffset, OpBuilder &builder) {
+  auto dstIntegerType = builder.getIntegerType(dstBits);
+  auto maskRightAlignedAttr =
+      builder.getIntegerAttr(dstIntegerType, (1 << srcBits) - 1);
+  Value maskRightAligned = builder.create<arith::ConstantOp>(
+      loc, dstIntegerType, maskRightAlignedAttr);
+  Value writeMaskInverse =
+      builder.create<arith::ShLIOp>(loc, maskRightAligned, bitwidthOffset);
+  auto flipValAttr = builder.getIntegerAttr(dstIntegerType, -1);
+  Value flipVal =
+      builder.create<arith::ConstantOp>(loc, dstIntegerType, flipValAttr);
+  return builder.create<arith::XOrIOp>(loc, writeMaskInverse, flipVal);
+}
+
+/// Returns the scaled linearized index based on the `srcBits` and `dstBits`
+/// sizes. The input `linearizedIndex` has the granularity of `srcBits`, and
+/// the returned index has the granularity of `dstBits`.
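+/// E.g., with srcBits = 4 and dstBits = 8, the i4 element at index 5 lives
+/// in byte 2.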
+static Value getIndicesForLoadOrStore(OpBuilder &builder, Location loc,
+                                      OpFoldResult linearizedIndex,
+                                      int64_t srcBits, int64_t dstBits) {
+  AffineExpr s0;
+  bindSymbols(builder.getContext(), s0);
+  int64_t scaler = dstBits / srcBits;
+  OpFoldResult scaledLinearizedIndices = affine::makeComposedFoldedAffineApply(
+      builder, loc, s0.floorDiv(scaler), {linearizedIndex});
+  return getValueOrCreateConstantIndexOp(builder, loc, scaledLinearizedIndices);
+}
+
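+/// Linearizes the `indices` of a `srcBits`-element access into `memref` using
+/// its strided metadata. The result is still expressed in `srcBits` units;
+/// scaling to the emulated element size is done separately.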
+static OpFoldResult
+getLinearizedSrcIndices(OpBuilder &builder, Location loc, int64_t srcBits,
+                        const SmallVector<OpFoldResult> &indices,
+                        Value memref) {
+  auto stridedMetadata =
+      builder.create<memref::ExtractStridedMetadataOp>(loc, memref);
+  OpFoldResult linearizedIndices;
+  std::tie(std::ignore, linearizedIndices) =
+      memref::getLinearizedMemRefOffsetAndSize(
+          builder, loc, srcBits, srcBits,
+          stridedMetadata.getConstifiedMixedOffset(),
+          stridedMetadata.getConstifiedMixedSizes(),
+          stridedMetadata.getConstifiedMixedStrides(), indices);
+  return linearizedIndices;
+}
+
 namespace {
 
 //===----------------------------------------------------------------------===//
@@ -218,32 +271,15 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
       bitsLoad = rewriter.create<memref::LoadOp>(loc, adaptor.getMemref(),
                                                  ValueRange{});
     } else {
-      SmallVector<OpFoldResult> indices =
-          getAsOpFoldResult(adaptor.getIndices());
-
-      auto stridedMetadata = rewriter.create<memref::ExtractStridedMetadataOp>(
-          loc, op.getMemRef());
-
       // Linearize the indices of the original load instruction. Do not account
       // for the scaling yet. This will be accounted for later.
-      OpFoldResult linearizedIndices;
-      std::tie(std::ignore, linearizedIndices) =
-          memref::getLinearizedMemRefOffsetAndSize(
-              rewriter, loc, srcBits, srcBits,
-              stridedMetadata.getConstifiedMixedOffset(),
-              stridedMetadata.getConstifiedMixedSizes(),
-              stridedMetadata.getConstifiedMixedStrides(), indices);
-
-      AffineExpr s0;
-      bindSymbols(rewriter.getContext(), s0);
-      int64_t scaler = dstBits / srcBits;
-      OpFoldResult scaledLinearizedIndices =
-          affine::makeComposedFoldedAffineApply(
-              rewriter, loc, s0.floorDiv(scaler), {linearizedIndices});
+      OpFoldResult linearizedIndices = getLinearizedSrcIndices(
+          rewriter, loc, srcBits, adaptor.getIndices(), op.getMemRef());
+
       Value newLoad = rewriter.create<memref::LoadOp>(
           loc, adaptor.getMemref(),
-          getValueOrCreateConstantIndexOp(rewriter, loc,
-                                          scaledLinearizedIndices));
+          getIndicesForLoadOrStore(rewriter, loc, linearizedIndices, srcBits,
+                                   dstBits));
 
       // Get the offset and shift the bits to the rightmost.
       // Note, currently only the big-endian is supported.
@@ -305,6 +341,74 @@ struct ConvertMemRefReinterpretCast final
   }
 };
 
+//===----------------------------------------------------------------------===//
+// ConvertMemrefStore
+//===----------------------------------------------------------------------===//
+
+struct ConvertMemrefStore final : OpConversionPattern<memref::StoreOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto convertedType = adaptor.getMemref().getType().cast<MemRefType>();
+    int srcBits = op.getMemRefType().getElementTypeBitWidth();
+    int dstBits = convertedType.getElementTypeBitWidth();
+    auto dstIntegerType = rewriter.getIntegerType(dstBits);
+    if (dstBits % srcBits != 0) {
+      return rewriter.notifyMatchFailure(
+          op, "only dstBits % srcBits == 0 supported");
+    }
+
+    Location loc = op.getLoc();
+    Value extendedInput = rewriter.create<arith::ExtUIOp>(loc, dstIntegerType,
+                                                          adaptor.getValue());
+
+    // Special case 0-rank memref stores. We compute the mask at compile time.
+    if (convertedType.getRank() == 0) {
+      // Create a mask that clears the destination bits: ~((1 << srcBits) - 1).
+      auto writeMaskValAttr =
+          rewriter.getIntegerAttr(dstIntegerType, ~((1 << srcBits) - 1));
+      Value writeMask = rewriter.create<arith::ConstantOp>(loc, dstIntegerType,
+                                                           writeMaskValAttr);
+
+      // Clear destination bits
+      rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::andi,
+                                           writeMask, adaptor.getMemref(),
+                                           ValueRange{});
+      // Write src bits to destination
+      rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::ori,
+                                           extendedInput, adaptor.getMemref(),
+                                           ValueRange{});
+      rewriter.eraseOp(op);
+      return success();
+    }
+
+    OpFoldResult linearizedIndices = getLinearizedSrcIndices(
+        rewriter, loc, srcBits, adaptor.getIndices(), op.getMemRef());
+    Value storeIndices = getIndicesForLoadOrStore(
+        rewriter, loc, linearizedIndices, srcBits, dstBits);
+    Value bitwidthOffset = getOffsetForBitwidth(loc, linearizedIndices, srcBits,
+                                                dstBits, rewriter);
+    Value writeMask = getSubByteWriteMask(loc, linearizedIndices, srcBits,
+                                          dstBits, bitwidthOffset, rewriter);
+    // Align the value to write with the destination bits
+    Value alignedVal =
+        rewriter.create<arith::ShLIOp>(loc, extendedInput, bitwidthOffset);
+
+    // Clear destination bits
+    rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::andi,
+                                         writeMask, adaptor.getMemref(),
+                                         storeIndices);
+    // Write src bits to destination
+    rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::ori,
+                                         alignedVal, adaptor.getMemref(),
+                                         storeIndices);
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // ConvertMemRefSubview
 //===----------------------------------------------------------------------===//
@@ -350,9 +454,9 @@ void memref::populateMemRefNarrowTypeEmulationPatterns(
   // Populate `memref.*` conversion patterns.
   patterns.add<ConvertMemRefAllocation<memref::AllocOp>,
                ConvertMemRefAllocation<memref::AllocaOp>, ConvertMemRefLoad,
-               ConvertMemRefAssumeAlignment, ConvertMemRefSubview,
-               ConvertMemRefReinterpretCast>(typeConverter,
-                                             patterns.getContext());
+               ConvertMemrefStore, ConvertMemRefAssumeAlignment,
+               ConvertMemRefSubview, ConvertMemRefReinterpretCast>(
+      typeConverter, patterns.getContext());
   memref::populateResolveExtractStridedMetadataPatterns(patterns);
 }
 
diff --git a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
index dc32a59a1a14931..4c88fcc3d656355 100644
--- a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
+++ b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
@@ -265,3 +265,172 @@ func.func @memref_alloca_load_i4(%arg0: index) -> i4 {
 //      CHECK32:   %[[SHIFTRT:.+]] = arith.shrsi %[[LOADVAL]], %[[CAST]]
 //      CHECK32:   %[[TRUNC:.+]] = arith.trunci %[[SHIFTRT]] : i32 to i4
 //      CHECK32:   return %[[TRUNC]]
+
+// -----
+
+func.func @memref_store_i4(%arg0: index, %arg1: i4) -> () {
+    %0 = memref.alloc() : memref<5xi4>
+    memref.store %arg1, %0[%arg0] : memref<5xi4>
+    return
+}
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 floordiv 2)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 2) * 8)>
+//      CHECK: func @memref_store_i4(
+// CHECK-SAME:     %[[ARG0:.+]]: index, %[[ARG1:.+]]: i4
+//  CHECK-DAG:   %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
+//  CHECK-DAG:   %[[EXTUI:.+]] = arith.extui %[[ARG1]] : i4 to i8
+//  CHECK-DAG:   %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
+//  CHECK-DAG:   %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]]]
+//  CHECK-DAG:   %[[BITOFFSET_I8:.+]] = arith.index_cast %[[BITOFFSET]] : index to i8
+//  CHECK-DAG:   %[[MASK_BASE:.+]] = arith.constant 15 : i8
+//  CHECK-DAG:   %[[MASK_SHIFTED:.+]] = arith.shli %[[MASK_BASE]], %[[BITOFFSET_I8]] : i8
+//  CHECK-DAG:   %[[CST_NEG_ONE:.+]] = arith.constant -1 : i8
+//  CHECK-DAG:   %[[MASK:.+]] = arith.xori %[[MASK_SHIFTED]], %[[CST_NEG_ONE]] : i8
+//  CHECK-DAG:   %[[SHIFTED_VAL:.+]] = arith.shli %[[EXTUI]], %[[BITOFFSET_I8]] : i8
+//      CHECK:   %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %[[ALLOC]][%[[INDEX]]] : (i8, memref<3xi8>) -> i8
+//      CHECK:   %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[SHIFTED_VAL]], %[[ALLOC]][%[[INDEX]]] : (i8, memref<3xi8>) -> i8
+//      CHECK:   return
+
+//  CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 floordiv 8)>
+//  CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 8) * 32)>
+//      CHECK32: func @memref_store_i4(
+// CHECK32-SAME:     %[[ARG0:.+]]: index, %[[ARG1:.+]]: i4
+//  CHECK32-DAG:   %[[ALLOC:.+]] = memref.alloc() : memref<1xi32>
+//  CHECK32-DAG:   %[[EXTUI:.+]] = arith.extui %[[ARG1]] : i4 to i32
+//  CHECK32-DAG:   %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
+//  CHECK32-DAG:   %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]]]
+//  CHECK32-DAG:   %[[BITOFFSET_I32:.+]] = arith.index_cast %[[BITOFFSET]] : index to i32
+//  CHECK32-DAG:   %[[MASK_BASE:.+]] = arith.constant 15 : i32
+//  CHECK32-DAG:   %[[MASK_SHIFTED:.+]] = arith.shli %[[MASK_BASE]], %[[BITOFFSET_I32]] : i32
+//  CHECK32-DAG:   %[[CST_NEG_ONE:.+]] = arith.constant -1 : i32
+//  CHECK32-DAG:   %[[MASK:.+]] = arith.xori %[[MASK_SHIFTED]], %[[CST_NEG_ONE]] : i32
+//  CHECK32-DAG:   %[[SHIFTED_VAL:.+]] = arith.shli %[[EXTUI]], %[[BITOFFSET_I32]] : i32
+//      CHECK32:   %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %[[ALLOC]][%[[INDEX]]] : (i32, memref<1xi32>) -> i32
+//      CHECK32:   %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[SHIFTED_VAL]], %[[ALLOC]][%[[INDEX]]] : (i32, memref<1xi32>) -> i32
+//      CHECK32:   return
+
+// -----
+
+func.func @memref_store_i4_rank2(%arg0: index, %arg1: index, %arg2: i4) -> () {
+    %0 = memref.alloc() : memref<3x125xi4>
+    memref.assume_alignment %0, 64 : memref<3x125xi4>
+    memref.store %arg2, %0[%arg0,%arg1] : memref<3x125xi4>
+    return
+}
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * 125 + s1) floordiv 2)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 * 500 + s1 * 4 - ((s0 * 125 + s1) floordiv 2) * 8)>
+//      CHECK: func @memref_store_i4_rank2(
+// CHECK-SAME:     %[[ARG0:.+]]: index, %[[ARG1:.+]]: index, %[[ARG2:.+]]: i4
+//  CHECK-DAG:   %[[ALLOC:.+]] = memref.alloc() : memref<188xi8>
+//  CHECK-DAG:   memref.assume_alignment %[[ALLOC]], 64 : memref<188xi8>
+//  CHECK-DAG:   %[[EXTUI:.+]] = arith.extui %[[ARG2]] : i4 to i8
+//  CHECK-DAG:   %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
+//  CHECK-DAG:   %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]], %[[ARG1]]]
+//  CHECK-DAG:   %[[BITOFFSET_I8:.+]] = arith.index_cast %[[BITOFFSET]] : index to i8
+//  CHECK-DAG:   %[[MASK_BASE:.+]] = arith.constant 15 : i8
+//  CHECK-DAG:   %[[MASK_SHIFTED:.+]] = arith.shli %[[MASK_BASE]], %[[BITOFFSET_I8]] : i8
+//  CHECK-DAG:   %[[CST_NEG_ONE:.+]] = arith.constant -1 : i8
+//  CHECK-DAG:   %[[MASK:.+]] = arith.xori %[[MASK_SHIFTED]], %[[CST_NEG_ONE]] : i8
+//  CHECK-DAG:   %[[SHIFTED_VAL:.+]] = arith.shli %[[EXTUI]], %[[BITOFFSET_I8]] : i8
+//      CHECK:   %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %[[ALLOC]][%[[INDEX]]] : (i8, memref<188xi8>) -> i8
+//      CHECK:   %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[SHIFTED_VAL]], %[[ALLOC]][%[[INDEX]]] : (i8, memref<188xi8>) -> i8
+//      CHECK:   return
+
+//  CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * 125 + s1) floordiv 8)>
+//  CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 * 500 + s1 * 4 - ((s0 * 125 + s1) floordiv 8) * 32)>
+//      CHECK32: func @memref_store_i4_rank2(
+// CHECK32-SAME:     %[[ARG0:.+]]: index, %[[ARG1:.+]]: index, %[[ARG2:.+]]: i4
+//  CHECK32-DAG:   %[[ALLOC:.+]] = memref.alloc() : memref<47xi32>
+//  CHECK32-DAG:   memref.assume_alignment %[[ALLOC]], 64 : memref<47xi32>
+//  CHECK32-DAG:   %[[EXTUI:.+]] = arith.extui %[[ARG2]] : i4 to i32
+//  CHECK32-DAG:   %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
+//  CHECK32-DAG:   %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]], %[[ARG1]]]
+//  CHECK32-DAG:   %[[BITOFFSET_I32:.+]] = arith.index_cast %[[BITOFFSET]] : index to i32
+//  CHECK32-DAG:   %[[MASK_BASE:.+]] = arith.constant 15 : i32
+//  CHECK32-DAG:   %[[MASK_SHIFTED:.+]] = arith.shli %[[MASK_BASE]], %[[BITOFFSET_I32]] : i32
+//  CHECK32-DAG:   %[[CST_NEG_ONE:.+]] = arith.constant -1 : i32
+//  CHECK32-DAG:   %[[MASK:.+]] = arith.xori %[[MASK_SHIFTED]], %[[CST_NEG_ONE]] : i32
+//  CHECK32-DAG:   %[[SHIFTED_VAL:.+]] = arith.shli %[[EXTUI]], %[[BITOFFSET_I32]] : i32
+//      CHECK32:   %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %[[ALLOC]][%[[INDEX]]] : (i32, memref<47xi32>) -> i32
+//      CHECK32:   %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[SHIFTED_VAL]], %[[ALLOC]][%[[INDEX]]] : (i32, memref<47xi32>) -> i32
+//      CHECK32:   return
+
+// -----
+
+func.func @memref_store_i4_dynamic(%arg0: index, %arg1 : index, %arg2 : index, %arg3 : index, %arg4: i4) -> () {
+  %0 = memref.alloc(%arg0, %arg1) : memref<?x?xi4>
+  memref.store %arg4, %0[%arg2, %arg3] : memref<?x?xi4>
+  return
+}
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 2)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 + s0 * s1) floordiv 2)>
+//  CHECK-DAG: #[[MAP2:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) * 4 + s2 * 4 - ((s2 + s0 * s1) floordiv 2) * 8)>
+//      CHECK: func @memref_store_i4_dynamic(
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index
+// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: index
+// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: index
+// CHECK-SAME:     %[[ARG3:[a-zA-Z0-9]+]]: index
+// CHECK-SAME:     %[[ARG4:[a-zA-Z0-9]+]]: i4
+//  CHECK-DAG:   %[[SIZE:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
+//  CHECK-DAG:   %[[ALLOC:.+]] = memref.alloc(%[[SIZE]]) : memref<?xi8>
+//  CHECK-DAG:   %[[EXTUI:.+]] = arith.extui %[[ARG4]] : i4 to i8
+//  CHECK-DAG:   %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]]
+//  CHECK-DAG:   %[[BITOFFSET:.+]] = affine.apply #[[MAP2]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]]
+//  CHECK-DAG:   %[[BITOFFSET_I8:.+]] = arith.index_cast %[[BITOFFSET]] : index to i8
+//  CHECK-DAG:   %[[MASK_BASE:.+]] = arith.constant 15 : i8
+//  CHECK-DAG:   %[[MASK_SHIFTED:.+]] = arith.shli %[[MASK_BASE]], %[[BITOFFSET_I8]] : i8
+//  CHECK-DAG:   %[[CST_NEG_ONE:.+]] = arith.constant -1 : i8
+//  CHECK-DAG:   %[[MASK:.+]] = arith.xori %[[MASK_SHIFTED]], %[[CST_NEG_ONE]] : i8
+//  CHECK-DAG:   %[[SHIFTED_VAL:.+]] = arith.shli %[[EXTUI]], %[[BITOFFSET_I8]] : i8
+//      CHECK:   %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %[[ALLOC]][%[[INDEX]]] : (i8, memref<?xi8>) -> i8
+//      CHECK:   %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[SHIFTED_VAL]], %[[ALLOC]][%[[INDEX]]] : (i8, memref<?xi8>) -> i8
+//      CHECK:   return
+
+//  CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 8)>
+//  CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 + s0 * s1) floordiv 8)>
+//  CHECK32-DAG: #[[MAP2:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) * 4 + s2 * 4 - ((s2 + s0 * s1) floordiv 8) * 32)>
+//      CHECK32: func @memref_store_i4_dynamic(
+// CHECK32-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index
+// CHECK32-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: index
+// CHECK32-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: index
+// CHECK32-SAME:     %[[ARG3:[a-zA-Z0-9]+]]: index
+// CHECK32-SAME:     %[[ARG4:[a-zA-Z0-9]+]]: i4
+//  CHECK32-DAG:   %[[SIZE:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
+//  CHECK32-DAG:   %[[ALLOC:.+]] = memref.alloc(%[[SIZE]]) : memref<?xi32>
+//  CHECK32-DAG:   %[[EXTUI:.+]] = arith.extui %[[ARG4]] : i4 to i32
+//  CHECK32-DAG:   %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]]
+//  CHECK32-DAG:   %[[BITOFFSET:.+]] = affine.apply #[[MAP2]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]]
+//  CHECK32-DAG:   %[[BITOFFSET_I32:.+]] = arith.index_cast %[[BITOFFSET]] : index to i32
+//  CHECK32-DAG:   %[[MASK_BASE:.+]] = arith.constant 15 : i32
+//  CHECK32-DAG:   %[[MASK_SHIFTED:.+]] = arith.shli %[[MASK_BASE]], %[[BITOFFSET_I32]] : i32
+//  CHECK32-DAG:   %[[CST_NEG_ONE:.+]] = arith.constant -1 : i32
+//  CHECK32-DAG:   %[[MASK:.+]] = arith.xori %[[MASK_SHIFTED]], %[[CST_NEG_ONE]] : i32
+//  CHECK32-DAG:   %[[SHIFTED_VAL:.+]] = arith.shli %[[EXTUI]], %[[BITOFFSET_I32]] : i32
+//      CHECK32:   %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %[[ALLOC]][%[[INDEX]]] : (i32, memref<?xi32>) -> i32
+//      CHECK32:   %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[SHIFTED_VAL]], %[[ALLOC]][%[[INDEX]]] : (i32, memref<?xi32>) -> i32
+//      CHECK32:   return
+
+// -----
+
+func.func @rank_zero_memref_store(%arg0: i4) -> () {
+  %0 = memref.alloc() : memref<i4>
+  memref.store %arg0, %0[] : memref<i4>
+  return
+}
+// CHECK-LABEL: func @rank_zero_memref
+//  CHECK-SAME:     %[[ARG0:.+]]: i4
+//       CHECK:   %[[ALLOC:.+]] = memref.alloc() : memref<i8>
+//       CHECK:   %[[EXTUI:.+]] = arith.extui %[[ARG0]] : i4 to i8
+//       CHECK:   %[[MASK:.+]] = arith.constant -16 : i8
+//       CHECK:   %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %[[ALLOC]][] : (i8, memref<i8>) -> i8
+//       CHECK:   %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[EXTUI]], %[[ALLOC]][] : (i8, memref<i8>) -> i8
+//       CHECK:   return
+
+// CHECK32-LABEL: func @rank_zero_memref
+//  CHECK32-SAME:     %[[ARG0:.+]]: i4
+//       CHECK32:   %[[ALLOC:.+]] = memref.alloc() : memref<i32>
+//       CHECK32:   %[[EXTUI:.+]] = arith.extui %[[ARG0]] : i4 to i32
+//       CHECK32:   %[[MASK:.+]] = arith.constant -16 : i32
+//       CHECK32:   %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %[[ALLOC]][] : (i32, memref<i32>) -> i32
+//       CHECK32:   %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[EXTUI]], %[[ALLOC]][] : (i32, memref<i32>) -> i32
+//       CHECK32:   return


