[Mlir-commits] [mlir] [mlir] Add subbyte emulation support for `memref.store`. (PR #72004)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Nov 22 13:45:18 PST 2023
https://github.com/Max191 updated https://github.com/llvm/llvm-project/pull/72004
>From f1c15b5850998c39e1ab6438e7c9a4ec4d581137 Mon Sep 17 00:00:00 2001
From: Max Dawkins <max.dawkins at gmail.com>
Date: Fri, 10 Nov 2023 17:06:43 -0500
Subject: [PATCH] [mlir] Add subbyte emulation support for `memref.store`. This
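adds a conversion for narrow type emulation of memref.store ops. The
conversion replaces the memref.store with two memref.atomic_rmw ops: an
`andi` that clears the destination bits, followed by an `ori` that writes the
new value. Atomics are needed because several sub-byte elements share one
byte: a plain read-modify-write could lose a concurrent store by another
thread to a neighbouring element in the same byte.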
---
.../MemRef/Transforms/EmulateNarrowType.cpp | 164 ++++++++++++++---
.../Dialect/MemRef/emulate-narrow-type.mlir | 169 ++++++++++++++++++
2 files changed, 308 insertions(+), 25 deletions(-)
diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
index 9f58e9055acadbb..553af2adb60f3d6 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
@@ -17,6 +17,9 @@
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/OpDefinition.h"
#include "mlir/Support/MathExtras.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/FormatVariadic.h"
@@ -29,6 +32,26 @@ using namespace mlir;
// Utility functions
//===----------------------------------------------------------------------===//
+/// Lowers `op` into two memref::AtomicRMWOps on `memref` at `storeIndices`
+/// and then erases `op`:
+///   1. an `andi` with `writeMask`, which zeroes the destination bits, and
+///   2. an `ori` with `storeVal`, which writes the new bits into place.
+/// Both `writeMask` and `storeVal` must already have the bitwidth of the
+/// destination element type, and `storeVal` must be zero everywhere except
+/// at the bits that correspond to the store destination.
+static void replaceStoreWithAtomics(ConversionPatternRewriter &rewriter,
+                                    memref::StoreOp op, Value writeMask,
+                                    Value storeVal, Value memref,
+                                    ValueRange storeIndices) {
+  Location storeLoc = op.getLoc();
+  // Step 1: clear the bits that are about to be overwritten.
+  rewriter.create<memref::AtomicRMWOp>(storeLoc, arith::AtomicRMWKind::andi,
+                                       writeMask, memref, storeIndices);
+  // Step 2: OR the (already aligned) source bits into the cleared slot.
+  rewriter.create<memref::AtomicRMWOp>(storeLoc, arith::AtomicRMWKind::ori,
+                                       storeVal, memref, storeIndices);
+  rewriter.eraseOp(op);
+}
+
/// When data is loaded/stored in `targetBits` granularity, but is used in
/// `sourceBits` granularity (`sourceBits` < `targetBits`), the `targetBits` is
/// treated as an array of elements of width `sourceBits`.
@@ -43,13 +66,67 @@ static Value getOffsetForBitwidth(Location loc, OpFoldResult srcIdx,
AffineExpr s0;
bindSymbols(builder.getContext(), s0);
int scaleFactor = targetBits / sourceBits;
- OpFoldResult offsetVal = affine::makeComposedFoldedAffineApply(
- builder, loc, (s0 % scaleFactor) * sourceBits, {srcIdx});
+ AffineExpr offsetExpr = (s0 % scaleFactor) * sourceBits;
+ OpFoldResult offsetVal =
+ affine::makeComposedFoldedAffineApply(builder, loc, offsetExpr, {srcIdx});
Value bitOffset = getValueOrCreateConstantIndexOp(builder, loc, offsetVal);
IntegerType dstType = builder.getIntegerType(targetBits);
return builder.create<arith::IndexCastOp>(loc, dstType, bitOffset);
}
+/// When writing a subbyte size, writing needs to happen atomically in case of
+/// another write happening on the same byte at the same time. Before the new
+/// value is OR-ed in, the `srcBits`-wide slot it occupies inside the
+/// `dstBits`-wide word must be cleared. This function returns the `andi` mask
+/// for that: an i`dstBits` value that is all ones except for the `srcBits`
+/// bits starting at bit `bitwidthOffset`, i.e.
+///   ~(((1 << srcBits) - 1) << bitwidthOffset).
+///
+/// `linearizedIndices` is currently unused; the bit position comes entirely
+/// from `bitwidthOffset`, which must have type i`dstBits`.
+static Value getAtomicWriteMask(Location loc, OpFoldResult linearizedIndices,
+                                int64_t srcBits, int64_t dstBits,
+                                Value bitwidthOffset, OpBuilder &builder) {
+  auto dstIntegerType = builder.getIntegerType(dstBits);
+  // Build the right-aligned mask with APInt instead of `(1 << srcBits) - 1`:
+  // the `int` shift overflows (undefined behavior) once srcBits >= 31.
+  auto maskRightAlignedAttr = builder.getIntegerAttr(
+      dstIntegerType, APInt::getLowBitsSet(dstBits, srcBits));
+  Value maskRightAligned =
+      builder
+          .create<arith::ConstantOp>(loc, dstIntegerType, maskRightAlignedAttr)
+          .getResult();
+  // Move the ones over the destination slot.
+  Value writeMaskInverse =
+      builder.create<arith::ShLIOp>(loc, maskRightAligned, bitwidthOffset);
+  // XOR with all-ones inverts the mask so only the destination slot is zero.
+  auto flipValAttr = builder.getIntegerAttr(dstIntegerType, -1);
+  Value flipVal =
+      builder.create<arith::ConstantOp>(loc, dstIntegerType, flipValAttr)
+          .getResult();
+  return builder.create<arith::XOrIOp>(loc, writeMaskInverse, flipVal);
+}
+
+/// Rescales `linearizedIndex`, which counts `srcBits`-wide elements, into an
+/// index that counts `dstBits`-wide elements:
+///   linearizedIndex floordiv (dstBits / srcBits).
+/// The result is returned as a Value; an index constant is created when the
+/// affine apply folds to a constant.
+static Value getIndicesForLoadOrStore(OpBuilder &builder, Location loc,
+                                      OpFoldResult linearizedIndex,
+                                      int64_t srcBits, int64_t dstBits) {
+  const int64_t elementsPerWord = dstBits / srcBits;
+  AffineExpr sym;
+  bindSymbols(builder.getContext(), sym);
+  AffineExpr wordIndexExpr = sym.floorDiv(elementsPerWord);
+  OpFoldResult wordIndex = affine::makeComposedFoldedAffineApply(
+      builder, loc, wordIndexExpr, {linearizedIndex});
+  return getValueOrCreateConstantIndexOp(builder, loc, wordIndex);
+}
+
+/// Linearizes `indices` into `memref` using the offset, sizes and strides
+/// taken from its strided metadata. Both bitwidths passed to the linearizer
+/// are `srcBits`, so the result is still in units of source elements; callers
+/// apply any scaling to the emulated element type themselves.
+static OpFoldResult
+getLinearizedSrcIndices(OpBuilder &builder, Location loc, int64_t srcBits,
+                        const SmallVector<OpFoldResult> &indices,
+                        Value memref) {
+  auto metadata = builder.create<memref::ExtractStridedMetadataOp>(loc, memref);
+  // Only the linearized index (the pair's second element) is needed here.
+  return memref::getLinearizedMemRefOffsetAndSize(
+             builder, loc, srcBits, srcBits,
+             metadata.getConstifiedMixedOffset(),
+             metadata.getConstifiedMixedSizes(),
+             metadata.getConstifiedMixedStrides(), indices)
+      .second;
+}
+
namespace {
//===----------------------------------------------------------------------===//
@@ -155,32 +232,15 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
bitsLoad = rewriter.create<memref::LoadOp>(loc, adaptor.getMemref(),
ValueRange{});
} else {
- SmallVector<OpFoldResult> indices =
- getAsOpFoldResult(adaptor.getIndices());
-
- auto stridedMetadata = rewriter.create<memref::ExtractStridedMetadataOp>(
- loc, op.getMemRef());
-
// Linearize the indices of the original load instruction. Do not account
// for the scaling yet. This will be accounted for later.
- OpFoldResult linearizedIndices;
- std::tie(std::ignore, linearizedIndices) =
- memref::getLinearizedMemRefOffsetAndSize(
- rewriter, loc, srcBits, srcBits,
- stridedMetadata.getConstifiedMixedOffset(),
- stridedMetadata.getConstifiedMixedSizes(),
- stridedMetadata.getConstifiedMixedStrides(), indices);
-
- AffineExpr s0;
- bindSymbols(rewriter.getContext(), s0);
- int64_t scaler = dstBits / srcBits;
- OpFoldResult scaledLinearizedIndices =
- affine::makeComposedFoldedAffineApply(
- rewriter, loc, s0.floorDiv(scaler), {linearizedIndices});
+ OpFoldResult linearizedIndices = getLinearizedSrcIndices(
+ rewriter, loc, srcBits, adaptor.getIndices(), op.getMemRef());
+
Value newLoad = rewriter.create<memref::LoadOp>(
loc, adaptor.getMemref(),
- getValueOrCreateConstantIndexOp(rewriter, loc,
- scaledLinearizedIndices));
+ getIndicesForLoadOrStore(rewriter, loc, linearizedIndices, srcBits,
+ dstBits));
// Get the offset and shift the bits to the rightmost.
// Note, currently only the big-endian is supported.
@@ -211,6 +271,60 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
}
};
+//===----------------------------------------------------------------------===//
+// ConvertMemrefStore
+//===----------------------------------------------------------------------===//
+
+/// Emulates a `memref.store` of a narrow (`srcBits`-wide) element into a
+/// memref whose converted element type is `dstBits` wide. The value is
+/// zero-extended and shifted into position, then written with an atomic
+/// `andi` (clear the slot) followed by an atomic `ori` (write the bits).
+/// Fails to match unless `dstBits` is a multiple of `srcBits`.
+struct ConvertMemrefStore final : OpConversionPattern<memref::StoreOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto convertedType = cast<MemRefType>(adaptor.getMemref().getType());
+    int srcBits = op.getMemRefType().getElementTypeBitWidth();
+    int dstBits = convertedType.getElementTypeBitWidth();
+    auto dstIntegerType = rewriter.getIntegerType(dstBits);
+    if (dstBits % srcBits != 0) {
+      return rewriter.notifyMatchFailure(
+          op, "only dstBits % srcBits == 0 supported");
+    }
+
+    Location loc = op.getLoc();
+    // Zero-extension leaves every bit above `srcBits` at zero, so the `ori`
+    // below only touches the bits that belong to this element.
+    Value extendedInput = rewriter.create<arith::ExtUIOp>(loc, dstIntegerType,
+                                                          adaptor.getValue());
+
+    // Special case 0-rank memref stores. The element always sits at bit
+    // offset 0, so the mask can be computed at compile time.
+    if (convertedType.getRank() == 0) {
+      // Clear exactly the low `srcBits` bits: ~((1 << srcBits) - 1), e.g.
+      // 0xF0 (-16) for an i4 stored in an i8. Built with APInt to avoid an
+      // `int` shift overflow for wide source types.
+      auto writeMaskValAttr = rewriter.getIntegerAttr(
+          dstIntegerType, ~APInt::getLowBitsSet(dstBits, srcBits));
+      Value writeMask = rewriter.create<arith::ConstantOp>(loc, dstIntegerType,
+                                                           writeMaskValAttr);
+
+      replaceStoreWithAtomics(rewriter, op, writeMask, extendedInput,
+                              adaptor.getMemref(), ValueRange{});
+      return success();
+    }
+
+    // Linearize in units of source elements, then derive the index of the
+    // containing `dstBits` word and the bit offset inside that word.
+    OpFoldResult linearizedIndices = getLinearizedSrcIndices(
+        rewriter, loc, srcBits, adaptor.getIndices(), op.getMemRef());
+    Value storeIndices = getIndicesForLoadOrStore(
+        rewriter, loc, linearizedIndices, srcBits, dstBits);
+    Value bitwidthOffset = getOffsetForBitwidth(loc, linearizedIndices, srcBits,
+                                                dstBits, rewriter);
+    Value writeMask = getAtomicWriteMask(loc, linearizedIndices, srcBits,
+                                         dstBits, bitwidthOffset, rewriter);
+    // Align the value to write with the destination bits
+    Value alignedVal =
+        rewriter.create<arith::ShLIOp>(loc, extendedInput, bitwidthOffset);
+    replaceStoreWithAtomics(rewriter, op, writeMask, alignedVal,
+                            adaptor.getMemref(), storeIndices);
+    return success();
+  }
+};
+
//===----------------------------------------------------------------------===//
// ConvertMemRefSubview
//===----------------------------------------------------------------------===//
@@ -292,7 +406,7 @@ void memref::populateMemRefNarrowTypeEmulationPatterns(
// Populate `memref.*` conversion patterns.
patterns.add<ConvertMemRefAlloc, ConvertMemRefLoad,
- ConvertMemRefAssumeAlignment, ConvertMemRefSubview>(
+ ConvertMemRefAssumeAlignment, ConvertMemRefSubview, ConvertMemrefStore>(
typeConverter, patterns.getContext());
memref::populateResolveExtractStridedMetadataPatterns(patterns);
}
diff --git a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
index 6ed97f05aa7cff2..22c5947fd2ac97b 100644
--- a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
+++ b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
@@ -174,3 +174,172 @@ func.func @memref_strided_i4(%idx : index) -> i4 {
// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<16xi32>
// CHECK32: %[[SUBVIEW:.+]] = memref.subview %[[ALLOC]][4] [4] [1] : memref<16xi32> to memref<4xi32, strided<[1], offset: 4>>
// CHECK32: %[[LOAD:.+]] = memref.load %[[SUBVIEW]]
+
+// -----
+
+func.func @memref_store_i4(%arg0: index, %arg1: i4) -> () {
+ %0 = memref.alloc() : memref<5xi4>
+ memref.store %arg1, %0[%arg0] : memref<5xi4>
+ return
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 floordiv 2)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 2) * 8)>
+// CHECK: func @memref_store_i4(
+// CHECK-SAME: %[[ARG0:.+]]: index, %[[ARG1:.+]]: i4
+// CHECK-DAG: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
+// CHECK-DAG: %[[EXTUI:.+]] = arith.extui %[[ARG1]] : i4 to i8
+// CHECK-DAG: %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
+// CHECK-DAG: %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]]]
+// CHECK-DAG: %[[BITOFFSET_I8:.+]] = arith.index_cast %[[BITOFFSET]] : index to i8
+// CHECK-DAG: %[[MASK_BASE:.+]] = arith.constant 15 : i8
+// CHECK-DAG: %[[MASK_SHIFTED:.+]] = arith.shli %[[MASK_BASE]], %[[BITOFFSET_I8]] : i8
+// CHECK-DAG: %[[CST_NEG_ONE:.+]] = arith.constant -1 : i8
+// CHECK-DAG: %[[MASK:.+]] = arith.xori %[[MASK_SHIFTED]], %[[CST_NEG_ONE]] : i8
+// CHECK-DAG: %[[SHIFTED_VAL:.+]] = arith.shli %[[EXTUI]], %[[BITOFFSET_I8]] : i8
+// CHECK: %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %[[ALLOC]][%[[INDEX]]] : (i8, memref<3xi8>) -> i8
+// CHECK: %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[SHIFTED_VAL]], %[[ALLOC]][%[[INDEX]]] : (i8, memref<3xi8>) -> i8
+// CHECK: return
+
+// CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 floordiv 8)>
+// CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 8) * 32)>
+// CHECK32: func @memref_store_i4(
+// CHECK32-SAME: %[[ARG0:.+]]: index, %[[ARG1:.+]]: i4
+// CHECK32-DAG: %[[ALLOC:.+]] = memref.alloc() : memref<1xi32>
+// CHECK32-DAG: %[[EXTUI:.+]] = arith.extui %[[ARG1]] : i4 to i32
+// CHECK32-DAG: %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
+// CHECK32-DAG: %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]]]
+// CHECK32-DAG: %[[BITOFFSET_I32:.+]] = arith.index_cast %[[BITOFFSET]] : index to i32
+// CHECK32-DAG: %[[MASK_BASE:.+]] = arith.constant 15 : i32
+// CHECK32-DAG: %[[MASK_SHIFTED:.+]] = arith.shli %[[MASK_BASE]], %[[BITOFFSET_I32]] : i32
+// CHECK32-DAG: %[[CST_NEG_ONE:.+]] = arith.constant -1 : i32
+// CHECK32-DAG: %[[MASK:.+]] = arith.xori %[[MASK_SHIFTED]], %[[CST_NEG_ONE]] : i32
+// CHECK32-DAG: %[[SHIFTED_VAL:.+]] = arith.shli %[[EXTUI]], %[[BITOFFSET_I32]] : i32
+// CHECK32: %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %[[ALLOC]][%[[INDEX]]] : (i32, memref<1xi32>) -> i32
+// CHECK32: %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[SHIFTED_VAL]], %[[ALLOC]][%[[INDEX]]] : (i32, memref<1xi32>) -> i32
+// CHECK32: return
+
+// -----
+
+func.func @memref_store_i4_rank2(%arg0: index, %arg1: index, %arg2: i4) -> () {
+ %0 = memref.alloc() : memref<3x125xi4>
+ memref.assume_alignment %0, 64 : memref<3x125xi4>
+ memref.store %arg2, %0[%arg0,%arg1] : memref<3x125xi4>
+ return
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * 125 + s1) floordiv 2)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 * 500 + s1 * 4 - ((s0 * 125 + s1) floordiv 2) * 8)>
+// CHECK: func @memref_store_i4_rank2(
+// CHECK-SAME: %[[ARG0:.+]]: index, %[[ARG1:.+]]: index, %[[ARG2:.+]]: i4
+// CHECK-DAG: %[[ALLOC:.+]] = memref.alloc() : memref<188xi8>
+// CHECK-DAG: memref.assume_alignment %[[ALLOC]], 64 : memref<188xi8>
+// CHECK-DAG: %[[EXTUI:.+]] = arith.extui %[[ARG2]] : i4 to i8
+// CHECK-DAG: %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
+// CHECK-DAG: %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]], %[[ARG1]]]
+// CHECK-DAG: %[[BITOFFSET_I8:.+]] = arith.index_cast %[[BITOFFSET]] : index to i8
+// CHECK-DAG: %[[MASK_BASE:.+]] = arith.constant 15 : i8
+// CHECK-DAG: %[[MASK_SHIFTED:.+]] = arith.shli %[[MASK_BASE]], %[[BITOFFSET_I8]] : i8
+// CHECK-DAG: %[[CST_NEG_ONE:.+]] = arith.constant -1 : i8
+// CHECK-DAG: %[[MASK:.+]] = arith.xori %[[MASK_SHIFTED]], %[[CST_NEG_ONE]] : i8
+// CHECK-DAG: %[[SHIFTED_VAL:.+]] = arith.shli %[[EXTUI]], %[[BITOFFSET_I8]] : i8
+// CHECK: %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %[[ALLOC]][%[[INDEX]]] : (i8, memref<188xi8>) -> i8
+// CHECK: %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[SHIFTED_VAL]], %[[ALLOC]][%[[INDEX]]] : (i8, memref<188xi8>) -> i8
+// CHECK: return
+
+// CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * 125 + s1) floordiv 8)>
+// CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 * 500 + s1 * 4 - ((s0 * 125 + s1) floordiv 8) * 32)>
+// CHECK32: func @memref_store_i4_rank2(
+// CHECK32-SAME: %[[ARG0:.+]]: index, %[[ARG1:.+]]: index, %[[ARG2:.+]]: i4
+// CHECK32-DAG: %[[ALLOC:.+]] = memref.alloc() : memref<47xi32>
+// CHECK32-DAG: memref.assume_alignment %[[ALLOC]], 64 : memref<47xi32>
+// CHECK32-DAG: %[[EXTUI:.+]] = arith.extui %[[ARG2]] : i4 to i32
+// CHECK32-DAG: %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
+// CHECK32-DAG: %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]], %[[ARG1]]]
+// CHECK32-DAG: %[[BITOFFSET_I32:.+]] = arith.index_cast %[[BITOFFSET]] : index to i32
+// CHECK32-DAG: %[[MASK_BASE:.+]] = arith.constant 15 : i32
+// CHECK32-DAG: %[[MASK_SHIFTED:.+]] = arith.shli %[[MASK_BASE]], %[[BITOFFSET_I32]] : i32
+// CHECK32-DAG: %[[CST_NEG_ONE:.+]] = arith.constant -1 : i32
+// CHECK32-DAG: %[[MASK:.+]] = arith.xori %[[MASK_SHIFTED]], %[[CST_NEG_ONE]] : i32
+// CHECK32-DAG: %[[SHIFTED_VAL:.+]] = arith.shli %[[EXTUI]], %[[BITOFFSET_I32]] : i32
+// CHECK32: %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %[[ALLOC]][%[[INDEX]]] : (i32, memref<47xi32>) -> i32
+// CHECK32: %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[SHIFTED_VAL]], %[[ALLOC]][%[[INDEX]]] : (i32, memref<47xi32>) -> i32
+// CHECK32: return
+
+// -----
+
+func.func @memref_store_i4_dynamic(%arg0: index, %arg1 : index, %arg2 : index, %arg3 : index, %arg4: i4) -> () {
+ %0 = memref.alloc(%arg0, %arg1) : memref<?x?xi4>
+ memref.store %arg4, %0[%arg2, %arg3] : memref<?x?xi4>
+ return
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 2)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 + s0 * s1) floordiv 2)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) * 4 + s2 * 4 - ((s2 + s0 * s1) floordiv 2) * 8)>
+// CHECK: func @memref_store_i4_dynamic(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: i4
+// CHECK-DAG: %[[SIZE:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
+// CHECK-DAG: %[[ALLOC:.+]] = memref.alloc(%[[SIZE]]) : memref<?xi8>
+// CHECK-DAG: %[[EXTUI:.+]] = arith.extui %[[ARG4]] : i4 to i8
+// CHECK-DAG: %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]]
+// CHECK-DAG: %[[BITOFFSET:.+]] = affine.apply #[[MAP2]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]]
+// CHECK-DAG: %[[BITOFFSET_I8:.+]] = arith.index_cast %[[BITOFFSET]] : index to i8
+// CHECK-DAG: %[[MASK_BASE:.+]] = arith.constant 15 : i8
+// CHECK-DAG: %[[MASK_SHIFTED:.+]] = arith.shli %[[MASK_BASE]], %[[BITOFFSET_I8]] : i8
+// CHECK-DAG: %[[CST_NEG_ONE:.+]] = arith.constant -1 : i8
+// CHECK-DAG: %[[MASK:.+]] = arith.xori %[[MASK_SHIFTED]], %[[CST_NEG_ONE]] : i8
+// CHECK-DAG: %[[SHIFTED_VAL:.+]] = arith.shli %[[EXTUI]], %[[BITOFFSET_I8]] : i8
+// CHECK: %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %[[ALLOC]][%[[INDEX]]] : (i8, memref<?xi8>) -> i8
+// CHECK: %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[SHIFTED_VAL]], %[[ALLOC]][%[[INDEX]]] : (i8, memref<?xi8>) -> i8
+// CHECK: return
+
+// CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 8)>
+// CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 + s0 * s1) floordiv 8)>
+// CHECK32-DAG: #[[MAP2:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) * 4 + s2 * 4 - ((s2 + s0 * s1) floordiv 8) * 32)>
+// CHECK32: func @memref_store_i4_dynamic(
+// CHECK32-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index
+// CHECK32-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index
+// CHECK32-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
+// CHECK32-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index
+// CHECK32-SAME: %[[ARG4:[a-zA-Z0-9]+]]: i4
+// CHECK32-DAG: %[[SIZE:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
+// CHECK32-DAG: %[[ALLOC:.+]] = memref.alloc(%[[SIZE]]) : memref<?xi32>
+// CHECK32-DAG: %[[EXTUI:.+]] = arith.extui %[[ARG4]] : i4 to i32
+// CHECK32-DAG: %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]]
+// CHECK32-DAG: %[[BITOFFSET:.+]] = affine.apply #[[MAP2]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]]
+// CHECK32-DAG: %[[BITOFFSET_I32:.+]] = arith.index_cast %[[BITOFFSET]] : index to i32
+// CHECK32-DAG: %[[MASK_BASE:.+]] = arith.constant 15 : i32
+// CHECK32-DAG: %[[MASK_SHIFTED:.+]] = arith.shli %[[MASK_BASE]], %[[BITOFFSET_I32]] : i32
+// CHECK32-DAG: %[[CST_NEG_ONE:.+]] = arith.constant -1 : i32
+// CHECK32-DAG: %[[MASK:.+]] = arith.xori %[[MASK_SHIFTED]], %[[CST_NEG_ONE]] : i32
+// CHECK32-DAG: %[[SHIFTED_VAL:.+]] = arith.shli %[[EXTUI]], %[[BITOFFSET_I32]] : i32
+// CHECK32: %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %[[ALLOC]][%[[INDEX]]] : (i32, memref<?xi32>) -> i32
+// CHECK32: %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[SHIFTED_VAL]], %[[ALLOC]][%[[INDEX]]] : (i32, memref<?xi32>) -> i32
+// CHECK32: return
+
+// -----
+
+func.func @rank_zero_memref_store(%arg0: i4) -> () {
+ %0 = memref.alloc() : memref<i4>
+ memref.store %arg0, %0[] : memref<i4>
+ return
+}
+// The clear mask must zero exactly the low 4 bits: ~0xF, printed as -16.
+// CHECK-LABEL: func @rank_zero_memref
+// CHECK-SAME: %[[ARG0:.+]]: i4
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<i8>
+// CHECK: %[[EXTUI:.+]] = arith.extui %[[ARG0]] : i4 to i8
+// CHECK: %[[MASK:.+]] = arith.constant -16 : i8
+// CHECK: %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %[[ALLOC]][] : (i8, memref<i8>) -> i8
+// CHECK: %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[EXTUI]], %[[ALLOC]][] : (i8, memref<i8>) -> i8
+// CHECK: return
+
+// CHECK32-LABEL: func @rank_zero_memref
+// CHECK32-SAME: %[[ARG0:.+]]: i4
+// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<i32>
+// CHECK32: %[[EXTUI:.+]] = arith.extui %[[ARG0]] : i4 to i32
+// CHECK32: %[[MASK:.+]] = arith.constant -16 : i32
+// CHECK32: %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %[[ALLOC]][] : (i32, memref<i32>) -> i32
+// CHECK32: %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[EXTUI]], %[[ALLOC]][] : (i32, memref<i32>) -> i32
+// CHECK32: return
More information about the Mlir-commits
mailing list