[Mlir-commits] [mlir] [mlir] Add subbyte emulation support for `memref.store`. (PR #73174)
llvmlistbot at llvm.org
Tue Nov 28 11:26:31 PST 2023
https://github.com/Max191 updated https://github.com/llvm/llvm-project/pull/73174
>From 229aa00408d79f0d1700287373574a6a6aa07bad Mon Sep 17 00:00:00 2001
From: Max Dawkins <max.dawkins at gmail.com>
Date: Fri, 10 Nov 2023 17:06:43 -0500
Subject: [PATCH 1/3] [mlir] Add subbyte emulation support for `memref.store`.
This adds a conversion for narrow type emulation of `memref.store` ops. The
conversion replaces the `memref.store` with two `memref.atomic_rmw` ops: one to
clear the destination bits and one to write the source bits. Atomics are used
to prevent race conditions when two threads store sub-byte values that share
the same emulated byte.
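For reference, a sketch of the lowering for the i4-into-i8 case; it mirrors the
new `@memref_store_i4` test below (SSA value names are illustrative, not what
the pass actually prints):

  // memref.store %arg1, %0[%arg0] : memref<5xi4>  becomes roughly:
  %alloc     = memref.alloc() : memref<3xi8>   // 5 x i4 packed into 3 bytes
  %ext       = arith.extui %arg1 : i4 to i8
  %index     = affine.apply affine_map<()[s0] -> (s0 floordiv 2)>()[%arg0]
  %bitoffset = affine.apply affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 2) * 8)>()[%arg0]
  %offset_i8 = arith.index_cast %bitoffset : index to i8
  %mask_base = arith.constant 15 : i8               // 0x0F, one i4 worth of bits
  %mask_shl  = arith.shli %mask_base, %offset_i8 : i8
  %cneg1     = arith.constant -1 : i8
  %mask      = arith.xori %mask_shl, %cneg1 : i8    // 0s at the destination nibble, 1s elsewhere
  %val_shl   = arith.shli %ext, %offset_i8 : i8
  // Clear the destination nibble, then OR in the new bits. Atomics guard
  // against a racing store to the other nibble of the same byte.
  %cleared = memref.atomic_rmw andi %mask, %alloc[%index] : (i8, memref<3xi8>) -> i8
  %written = memref.atomic_rmw ori %val_shl, %alloc[%index] : (i8, memref<3xi8>) -> i8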
---
.../MemRef/Transforms/EmulateNarrowType.cpp | 158 +++++++++++++---
.../Dialect/MemRef/emulate-narrow-type.mlir | 169 ++++++++++++++++++
2 files changed, 300 insertions(+), 27 deletions(-)
diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
index e5801c3733ed5a8..ec94ecdcf62502b 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
@@ -17,7 +17,9 @@
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/OpDefinition.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Support/MathExtras.h"
#include "mlir/Transforms/DialectConversion.h"
@@ -102,13 +104,64 @@ static Value getOffsetForBitwidth(Location loc, OpFoldResult srcIdx,
AffineExpr s0;
bindSymbols(builder.getContext(), s0);
int scaleFactor = targetBits / sourceBits;
- OpFoldResult offsetVal = affine::makeComposedFoldedAffineApply(
- builder, loc, (s0 % scaleFactor) * sourceBits, {srcIdx});
+ AffineExpr offsetExpr = (s0 % scaleFactor) * sourceBits;
+ OpFoldResult offsetVal =
+ affine::makeComposedFoldedAffineApply(builder, loc, offsetExpr, {srcIdx});
Value bitOffset = getValueOrCreateConstantIndexOp(builder, loc, offsetVal);
IntegerType dstType = builder.getIntegerType(targetBits);
return builder.create<arith::IndexCastOp>(loc, dstType, bitOffset);
}
+/// When writing a subbyte-sized value, masked bitwise operations are used to
+/// modify only the relevant bits. This function returns an `and` mask for
+/// clearing the destination bits in a subbyte write. E.g., when writing to the
+/// second i4 in an i32, 0xFFFFFF0F is created.
+static Value getSubByteWriteMask(Location loc, OpFoldResult linearizedIndices,
+ int64_t srcBits, int64_t dstBits,
+ Value bitwidthOffset, OpBuilder &builder) {
+ auto dstIntegerType = builder.getIntegerType(dstBits);
+ auto maskRightAlignedAttr =
+ builder.getIntegerAttr(dstIntegerType, (1 << srcBits) - 1);
+ Value maskRightAligned = builder.create<arith::ConstantOp>(
+ loc, dstIntegerType, maskRightAlignedAttr);
+ Value writeMaskInverse =
+ builder.create<arith::ShLIOp>(loc, maskRightAligned, bitwidthOffset);
+ auto flipValAttr = builder.getIntegerAttr(dstIntegerType, -1);
+ Value flipVal =
+ builder.create<arith::ConstantOp>(loc, dstIntegerType, flipValAttr);
+ return builder.create<arith::XOrIOp>(loc, writeMaskInverse, flipVal);
+}
+
+/// Returns the scaled linearized index based on the `srcBits` and `dstBits`
+/// sizes. The input `linearizedIndex` has the granularity of `srcBits`, and
+/// the returned index has the granularity of `dstBits`.
+static Value getIndicesForLoadOrStore(OpBuilder &builder, Location loc,
+ OpFoldResult linearizedIndex,
+ int64_t srcBits, int64_t dstBits) {
+ AffineExpr s0;
+ bindSymbols(builder.getContext(), s0);
+ int64_t scaler = dstBits / srcBits;
+ OpFoldResult scaledLinearizedIndices = affine::makeComposedFoldedAffineApply(
+ builder, loc, s0.floorDiv(scaler), {linearizedIndex});
+ return getValueOrCreateConstantIndexOp(builder, loc, scaledLinearizedIndices);
+}
+
+static OpFoldResult
+getLinearizedSrcIndices(OpBuilder &builder, Location loc, int64_t srcBits,
+ const SmallVector<OpFoldResult> &indices,
+ Value memref) {
+ auto stridedMetadata =
+ builder.create<memref::ExtractStridedMetadataOp>(loc, memref);
+ OpFoldResult linearizedIndices;
+ std::tie(std::ignore, linearizedIndices) =
+ memref::getLinearizedMemRefOffsetAndSize(
+ builder, loc, srcBits, srcBits,
+ stridedMetadata.getConstifiedMixedOffset(),
+ stridedMetadata.getConstifiedMixedSizes(),
+ stridedMetadata.getConstifiedMixedStrides(), indices);
+ return linearizedIndices;
+}
+
namespace {
//===----------------------------------------------------------------------===//
@@ -218,32 +271,15 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
bitsLoad = rewriter.create<memref::LoadOp>(loc, adaptor.getMemref(),
ValueRange{});
} else {
- SmallVector<OpFoldResult> indices =
- getAsOpFoldResult(adaptor.getIndices());
-
- auto stridedMetadata = rewriter.create<memref::ExtractStridedMetadataOp>(
- loc, op.getMemRef());
-
// Linearize the indices of the original load instruction. Do not account
// for the scaling yet. This will be accounted for later.
- OpFoldResult linearizedIndices;
- std::tie(std::ignore, linearizedIndices) =
- memref::getLinearizedMemRefOffsetAndSize(
- rewriter, loc, srcBits, srcBits,
- stridedMetadata.getConstifiedMixedOffset(),
- stridedMetadata.getConstifiedMixedSizes(),
- stridedMetadata.getConstifiedMixedStrides(), indices);
-
- AffineExpr s0;
- bindSymbols(rewriter.getContext(), s0);
- int64_t scaler = dstBits / srcBits;
- OpFoldResult scaledLinearizedIndices =
- affine::makeComposedFoldedAffineApply(
- rewriter, loc, s0.floorDiv(scaler), {linearizedIndices});
+ OpFoldResult linearizedIndices = getLinearizedSrcIndices(
+ rewriter, loc, srcBits, adaptor.getIndices(), op.getMemRef());
+
Value newLoad = rewriter.create<memref::LoadOp>(
loc, adaptor.getMemref(),
- getValueOrCreateConstantIndexOp(rewriter, loc,
- scaledLinearizedIndices));
+ getIndicesForLoadOrStore(rewriter, loc, linearizedIndices, srcBits,
+ dstBits));
// Get the offset and shift the bits to the rightmost.
// Note, currently only the big-endian is supported.
@@ -305,6 +341,74 @@ struct ConvertMemRefReinterpretCast final
}
};
+//===----------------------------------------------------------------------===//
+// ConvertMemrefStore
+//===----------------------------------------------------------------------===//
+
+struct ConvertMemrefStore final : OpConversionPattern<memref::StoreOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto convertedType = adaptor.getMemref().getType().cast<MemRefType>();
+ int srcBits = op.getMemRefType().getElementTypeBitWidth();
+ int dstBits = convertedType.getElementTypeBitWidth();
+ auto dstIntegerType = rewriter.getIntegerType(dstBits);
+ if (dstBits % srcBits != 0) {
+ return rewriter.notifyMatchFailure(
+ op, "only dstBits % srcBits == 0 supported");
+ }
+
+ Location loc = op.getLoc();
+ Value extendedInput = rewriter.create<arith::ExtUIOp>(loc, dstIntegerType,
+ adaptor.getValue());
+
+ // Special case 0-rank memref stores. We compute the mask at compile time.
+ if (convertedType.getRank() == 0) {
+ // Create mask to clear destination bits
+ auto writeMaskValAttr =
+ rewriter.getIntegerAttr(dstIntegerType, ~(1 << (srcBits)) - 1);
+ Value writeMask = rewriter.create<arith::ConstantOp>(loc, dstIntegerType,
+ writeMaskValAttr);
+
+ // Clear destination bits
+ rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::andi,
+ writeMask, adaptor.getMemref(),
+ ValueRange{});
+ // Write src bits to destination
+ rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::ori,
+ extendedInput, adaptor.getMemref(),
+ ValueRange{});
+ rewriter.eraseOp(op);
+ return success();
+ }
+
+ OpFoldResult linearizedIndices = getLinearizedSrcIndices(
+ rewriter, loc, srcBits, adaptor.getIndices(), op.getMemRef());
+ Value storeIndices = getIndicesForLoadOrStore(
+ rewriter, loc, linearizedIndices, srcBits, dstBits);
+ Value bitwidthOffset = getOffsetForBitwidth(loc, linearizedIndices, srcBits,
+ dstBits, rewriter);
+ Value writeMask = getSubByteWriteMask(loc, linearizedIndices, srcBits,
+ dstBits, bitwidthOffset, rewriter);
+ // Align the value to write with the destination bits
+ Value alignedVal =
+ rewriter.create<arith::ShLIOp>(loc, extendedInput, bitwidthOffset);
+
+ // Clear destination bits
+ rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::andi,
+ writeMask, adaptor.getMemref(),
+ storeIndices);
+ // Write src bits to destination
+ rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::ori,
+ alignedVal, adaptor.getMemref(),
+ storeIndices);
+ rewriter.eraseOp(op);
+ return success();
+ }
+};
+
//===----------------------------------------------------------------------===//
// ConvertMemRefSubview
//===----------------------------------------------------------------------===//
@@ -350,9 +454,9 @@ void memref::populateMemRefNarrowTypeEmulationPatterns(
// Populate `memref.*` conversion patterns.
patterns.add<ConvertMemRefAllocation<memref::AllocOp>,
ConvertMemRefAllocation<memref::AllocaOp>, ConvertMemRefLoad,
- ConvertMemRefAssumeAlignment, ConvertMemRefSubview,
- ConvertMemRefReinterpretCast>(typeConverter,
- patterns.getContext());
+ ConvertMemrefStore, ConvertMemRefAssumeAlignment,
+ ConvertMemRefSubview, ConvertMemRefReinterpretCast>(
+ typeConverter, patterns.getContext());
memref::populateResolveExtractStridedMetadataPatterns(patterns);
}
diff --git a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
index dc32a59a1a14931..4c88fcc3d656355 100644
--- a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
+++ b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
@@ -265,3 +265,172 @@ func.func @memref_alloca_load_i4(%arg0: index) -> i4 {
// CHECK32: %[[SHIFTRT:.+]] = arith.shrsi %[[LOADVAL]], %[[CAST]]
// CHECK32: %[[TRUNC:.+]] = arith.trunci %[[SHIFTRT]] : i32 to i4
// CHECK32: return %[[TRUNC]]
+
+// -----
+
+func.func @memref_store_i4(%arg0: index, %arg1: i4) -> () {
+ %0 = memref.alloc() : memref<5xi4>
+ memref.store %arg1, %0[%arg0] : memref<5xi4>
+ return
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 floordiv 2)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 2) * 8)>
+// CHECK: func @memref_store_i4(
+// CHECK-SAME: %[[ARG0:.+]]: index, %[[ARG1:.+]]: i4
+// CHECK-DAG: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
+// CHECK-DAG: %[[EXTUI:.+]] = arith.extui %[[ARG1]] : i4 to i8
+// CHECK-DAG: %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
+// CHECK-DAG: %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]]]
+// CHECK-DAG: %[[BITOFFSET_I8:.+]] = arith.index_cast %[[BITOFFSET]] : index to i8
+// CHECK-DAG: %[[MASK_BASE:.+]] = arith.constant 15 : i8
+// CHECK-DAG: %[[MASK_SHIFTED:.+]] = arith.shli %[[MASK_BASE]], %[[BITOFFSET_I8]] : i8
+// CHECK-DAG: %[[CST_NEG_ONE:.+]] = arith.constant -1 : i8
+// CHECK-DAG: %[[MASK:.+]] = arith.xori %[[MASK_SHIFTED]], %[[CST_NEG_ONE]] : i8
+// CHECK-DAG: %[[SHIFTED_VAL:.+]] = arith.shli %[[EXTUI]], %[[BITOFFSET_I8]] : i8
+// CHECK: %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %[[ALLOC]][%[[INDEX]]] : (i8, memref<3xi8>) -> i8
+// CHECK: %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[SHIFTED_VAL]], %[[ALLOC]][%[[INDEX]]] : (i8, memref<3xi8>) -> i8
+// CHECK: return
+
+// CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 floordiv 8)>
+// CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 8) * 32)>
+// CHECK32: func @memref_store_i4(
+// CHECK32-SAME: %[[ARG0:.+]]: index, %[[ARG1:.+]]: i4
+// CHECK32-DAG: %[[ALLOC:.+]] = memref.alloc() : memref<1xi32>
+// CHECK32-DAG: %[[EXTUI:.+]] = arith.extui %[[ARG1]] : i4 to i32
+// CHECK32-DAG: %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
+// CHECK32-DAG: %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]]]
+// CHECK32-DAG: %[[BITOFFSET_I32:.+]] = arith.index_cast %[[BITOFFSET]] : index to i32
+// CHECK32-DAG: %[[MASK_BASE:.+]] = arith.constant 15 : i32
+// CHECK32-DAG: %[[MASK_SHIFTED:.+]] = arith.shli %[[MASK_BASE]], %[[BITOFFSET_I32]] : i32
+// CHECK32-DAG: %[[CST_NEG_ONE:.+]] = arith.constant -1 : i32
+// CHECK32-DAG: %[[MASK:.+]] = arith.xori %[[MASK_SHIFTED]], %[[CST_NEG_ONE]] : i32
+// CHECK32-DAG: %[[SHIFTED_VAL:.+]] = arith.shli %[[EXTUI]], %[[BITOFFSET_I32]] : i32
+// CHECK32: %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %[[ALLOC]][%[[INDEX]]] : (i32, memref<1xi32>) -> i32
+// CHECK32: %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[SHIFTED_VAL]], %[[ALLOC]][%[[INDEX]]] : (i32, memref<1xi32>) -> i32
+// CHECK32: return
+
+// -----
+
+func.func @memref_store_i4_rank2(%arg0: index, %arg1: index, %arg2: i4) -> () {
+ %0 = memref.alloc() : memref<3x125xi4>
+ memref.assume_alignment %0, 64 : memref<3x125xi4>
+ memref.store %arg2, %0[%arg0,%arg1] : memref<3x125xi4>
+ return
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * 125 + s1) floordiv 2)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 * 500 + s1 * 4 - ((s0 * 125 + s1) floordiv 2) * 8)>
+// CHECK: func @memref_store_i4_rank2(
+// CHECK-SAME: %[[ARG0:.+]]: index, %[[ARG1:.+]]: index, %[[ARG2:.+]]: i4
+// CHECK-DAG: %[[ALLOC:.+]] = memref.alloc() : memref<188xi8>
+// CHECK-DAG: memref.assume_alignment %[[ALLOC]], 64 : memref<188xi8>
+// CHECK-DAG: %[[EXTUI:.+]] = arith.extui %[[ARG2]] : i4 to i8
+// CHECK-DAG: %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
+// CHECK-DAG: %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]], %[[ARG1]]]
+// CHECK-DAG: %[[BITOFFSET_I8:.+]] = arith.index_cast %[[BITOFFSET]] : index to i8
+// CHECK-DAG: %[[MASK_BASE:.+]] = arith.constant 15 : i8
+// CHECK-DAG: %[[MASK_SHIFTED:.+]] = arith.shli %[[MASK_BASE]], %[[BITOFFSET_I8]] : i8
+// CHECK-DAG: %[[CST_NEG_ONE:.+]] = arith.constant -1 : i8
+// CHECK-DAG: %[[MASK:.+]] = arith.xori %[[MASK_SHIFTED]], %[[CST_NEG_ONE]] : i8
+// CHECK-DAG: %[[SHIFTED_VAL:.+]] = arith.shli %[[EXTUI]], %[[BITOFFSET_I8]] : i8
+// CHECK: %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %[[ALLOC]][%[[INDEX]]] : (i8, memref<188xi8>) -> i8
+// CHECK: %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[SHIFTED_VAL]], %[[ALLOC]][%[[INDEX]]] : (i8, memref<188xi8>) -> i8
+// CHECK: return
+
+// CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * 125 + s1) floordiv 8)>
+// CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 * 500 + s1 * 4 - ((s0 * 125 + s1) floordiv 8) * 32)>
+// CHECK32: func @memref_store_i4_rank2(
+// CHECK32-SAME: %[[ARG0:.+]]: index, %[[ARG1:.+]]: index, %[[ARG2:.+]]: i4
+// CHECK32-DAG: %[[ALLOC:.+]] = memref.alloc() : memref<47xi32>
+// CHECK32-DAG: memref.assume_alignment %[[ALLOC]], 64 : memref<47xi32>
+// CHECK32-DAG: %[[EXTUI:.+]] = arith.extui %[[ARG2]] : i4 to i32
+// CHECK32-DAG: %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
+// CHECK32-DAG: %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]], %[[ARG1]]]
+// CHECK32-DAG: %[[BITOFFSET_I32:.+]] = arith.index_cast %[[BITOFFSET]] : index to i32
+// CHECK32-DAG: %[[MASK_BASE:.+]] = arith.constant 15 : i32
+// CHECK32-DAG: %[[MASK_SHIFTED:.+]] = arith.shli %[[MASK_BASE]], %[[BITOFFSET_I32]] : i32
+// CHECK32-DAG: %[[CST_NEG_ONE:.+]] = arith.constant -1 : i32
+// CHECK32-DAG: %[[MASK:.+]] = arith.xori %[[MASK_SHIFTED]], %[[CST_NEG_ONE]] : i32
+// CHECK32-DAG: %[[SHIFTED_VAL:.+]] = arith.shli %[[EXTUI]], %[[BITOFFSET_I32]] : i32
+// CHECK32: %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %[[ALLOC]][%[[INDEX]]] : (i32, memref<47xi32>) -> i32
+// CHECK32: %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[SHIFTED_VAL]], %[[ALLOC]][%[[INDEX]]] : (i32, memref<47xi32>) -> i32
+// CHECK32: return
+
+// -----
+
+func.func @memref_store_i4_dynamic(%arg0: index, %arg1 : index, %arg2 : index, %arg3 : index, %arg4: i4) -> () {
+ %0 = memref.alloc(%arg0, %arg1) : memref<?x?xi4>
+ memref.store %arg4, %0[%arg2, %arg3] : memref<?x?xi4>
+ return
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 2)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 + s0 * s1) floordiv 2)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) * 4 + s2 * 4 - ((s2 + s0 * s1) floordiv 2) * 8)>
+// CHECK: func @memref_store_i4_dynamic(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: i4
+// CHECK-DAG: %[[SIZE:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
+// CHECK-DAG: %[[ALLOC:.+]] = memref.alloc(%[[SIZE]]) : memref<?xi8>
+// CHECK-DAG: %[[EXTUI:.+]] = arith.extui %[[ARG4]] : i4 to i8
+// CHECK-DAG: %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]]
+// CHECK-DAG: %[[BITOFFSET:.+]] = affine.apply #[[MAP2]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]]
+// CHECK-DAG: %[[BITOFFSET_I8:.+]] = arith.index_cast %[[BITOFFSET]] : index to i8
+// CHECK-DAG: %[[MASK_BASE:.+]] = arith.constant 15 : i8
+// CHECK-DAG: %[[MASK_SHIFTED:.+]] = arith.shli %[[MASK_BASE]], %[[BITOFFSET_I8]] : i8
+// CHECK-DAG: %[[CST_NEG_ONE:.+]] = arith.constant -1 : i8
+// CHECK-DAG: %[[MASK:.+]] = arith.xori %[[MASK_SHIFTED]], %[[CST_NEG_ONE]] : i8
+// CHECK-DAG: %[[SHIFTED_VAL:.+]] = arith.shli %[[EXTUI]], %[[BITOFFSET_I8]] : i8
+// CHECK: %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %[[ALLOC]][%[[INDEX]]] : (i8, memref<?xi8>) -> i8
+// CHECK: %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[SHIFTED_VAL]], %[[ALLOC]][%[[INDEX]]] : (i8, memref<?xi8>) -> i8
+// CHECK: return
+
+// CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 8)>
+// CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 + s0 * s1) floordiv 8)>
+// CHECK32-DAG: #[[MAP2:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) * 4 + s2 * 4 - ((s2 + s0 * s1) floordiv 8) * 32)>
+// CHECK32: func @memref_store_i4_dynamic(
+// CHECK32-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index
+// CHECK32-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index
+// CHECK32-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
+// CHECK32-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index
+// CHECK32-SAME: %[[ARG4:[a-zA-Z0-9]+]]: i4
+// CHECK32-DAG: %[[SIZE:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
+// CHECK32-DAG: %[[ALLOC:.+]] = memref.alloc(%[[SIZE]]) : memref<?xi32>
+// CHECK32-DAG: %[[EXTUI:.+]] = arith.extui %[[ARG4]] : i4 to i32
+// CHECK32-DAG: %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]]
+// CHECK32-DAG: %[[BITOFFSET:.+]] = affine.apply #[[MAP2]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]]
+// CHECK32-DAG: %[[BITOFFSET_I32:.+]] = arith.index_cast %[[BITOFFSET]] : index to i32
+// CHECK32-DAG: %[[MASK_BASE:.+]] = arith.constant 15 : i32
+// CHECK32-DAG: %[[MASK_SHIFTED:.+]] = arith.shli %[[MASK_BASE]], %[[BITOFFSET_I32]] : i32
+// CHECK32-DAG: %[[CST_NEG_ONE:.+]] = arith.constant -1 : i32
+// CHECK32-DAG: %[[MASK:.+]] = arith.xori %[[MASK_SHIFTED]], %[[CST_NEG_ONE]] : i32
+// CHECK32-DAG: %[[SHIFTED_VAL:.+]] = arith.shli %[[EXTUI]], %[[BITOFFSET_I32]] : i32
+// CHECK32: %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %[[ALLOC]][%[[INDEX]]] : (i32, memref<?xi32>) -> i32
+// CHECK32: %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[SHIFTED_VAL]], %[[ALLOC]][%[[INDEX]]] : (i32, memref<?xi32>) -> i32
+// CHECK32: return
+
+// -----
+
+func.func @rank_zero_memref_store(%arg0: i4) -> () {
+ %0 = memref.alloc() : memref<i4>
+ memref.store %arg0, %0[] : memref<i4>
+ return
+}
+// CHECK-LABEL: func @rank_zero_memref
+// CHECK-SAME: %[[ARG0:.+]]: i4
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<i8>
+// CHECK: %[[EXTUI:.+]] = arith.extui %[[ARG0]] : i4 to i8
+// CHECK: %[[MASK:.+]] = arith.constant -18 : i8
+// CHECK: %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %[[ALLOC]][] : (i8, memref<i8>) -> i8
+// CHECK: %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[EXTUI]], %[[ALLOC]][] : (i8, memref<i8>) -> i8
+// CHECK: return
+
+// CHECK32-LABEL: func @rank_zero_memref
+// CHECK32-SAME: %[[ARG0:.+]]: i4
+// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<i32>
+// CHECK32: %[[EXTUI:.+]] = arith.extui %[[ARG0]] : i4 to i32
+// CHECK32: %[[MASK:.+]] = arith.constant -18 : i32
+// CHECK32: %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %[[ALLOC]][] : (i32, memref<i32>) -> i32
+// CHECK32: %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[EXTUI]], %[[ALLOC]][] : (i32, memref<i32>) -> i32
+// CHECK32: return
>From e99593e0754bdd55f0b30f32ba2cd6f5bd962e0c Mon Sep 17 00:00:00 2001
From: Max Dawkins <max.dawkins at gmail.com>
Date: Tue, 28 Nov 2023 14:08:52 -0500
Subject: [PATCH 2/3] simplify 0D case
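A 0-rank memref emulates to a single backing element, so there are no
neighboring sub-byte values in that element that need to be preserved; the
andi/ori pair can therefore be collapsed into a single atomic assign. Sketch of
the resulting IR for `@rank_zero_memref_store` (matching the updated test; SSA
names are illustrative):

  %alloc = memref.alloc() : memref<i8>
  %ext   = arith.extui %arg0 : i4 to i8
  %rmw   = memref.atomic_rmw assign %ext, %alloc[] : (i8, memref<i8>) -> i8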
---
.../Dialect/MemRef/Transforms/EmulateNarrowType.cpp | 12 +-----------
mlir/test/Dialect/MemRef/emulate-narrow-type.mlir | 8 ++------
2 files changed, 3 insertions(+), 17 deletions(-)
diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
index ec94ecdcf62502b..72f683972662fec 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
@@ -366,18 +366,8 @@ struct ConvertMemrefStore final : OpConversionPattern<memref::StoreOp> {
// Special case 0-rank memref stores. We compute the mask at compile time.
if (convertedType.getRank() == 0) {
- // Create mask to clear destination bits
- auto writeMaskValAttr =
- rewriter.getIntegerAttr(dstIntegerType, ~(1 << (srcBits)) - 1);
- Value writeMask = rewriter.create<arith::ConstantOp>(loc, dstIntegerType,
- writeMaskValAttr);
-
- // Clear destination bits
- rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::andi,
- writeMask, adaptor.getMemref(),
- ValueRange{});
// Write src bits to destination
- rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::ori,
+ rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::assign,
extendedInput, adaptor.getMemref(),
ValueRange{});
rewriter.eraseOp(op);
diff --git a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
index 4c88fcc3d656355..fd37b7ff0a27130 100644
--- a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
+++ b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
@@ -421,16 +421,12 @@ func.func @rank_zero_memref_store(%arg0: i4) -> () {
// CHECK-SAME: %[[ARG0:.+]]: i4
// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<i8>
// CHECK: %[[EXTUI:.+]] = arith.extui %[[ARG0]] : i4 to i8
-// CHECK: %[[MASK:.+]] = arith.constant -18 : i8
-// CHECK: %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %[[ALLOC]][] : (i8, memref<i8>) -> i8
-// CHECK: %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[EXTUI]], %[[ALLOC]][] : (i8, memref<i8>) -> i8
+// CHECK: %[[WRITE_RMW:.+]] = memref.atomic_rmw assign %[[EXTUI]], %[[ALLOC]][] : (i8, memref<i8>) -> i8
// CHECK: return
// CHECK32-LABEL: func @rank_zero_memref
// CHECK32-SAME: %[[ARG0:.+]]: i4
// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<i32>
// CHECK32: %[[EXTUI:.+]] = arith.extui %[[ARG0]] : i4 to i32
-// CHECK32: %[[MASK:.+]] = arith.constant -18 : i32
-// CHECK32: %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %[[ALLOC]][] : (i32, memref<i32>) -> i32
-// CHECK32: %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[EXTUI]], %[[ALLOC]][] : (i32, memref<i32>) -> i32
+// CHECK32: %[[WRITE_RMW:.+]] = memref.atomic_rmw assign %[[EXTUI]], %[[ALLOC]][] : (i32, memref<i32>) -> i32
// CHECK32: return
>From bb05eefac4dff47327ae6d9d12582d9035ee1018 Mon Sep 17 00:00:00 2001
From: Max Dawkins <max.dawkins at gmail.com>
Date: Tue, 28 Nov 2023 14:26:09 -0500
Subject: [PATCH 3/3] fix comment
---
mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp | 3 +--
1 file changed, 1 insertion(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
index 72f683972662fec..8236a4c475f17c5 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
@@ -364,9 +364,8 @@ struct ConvertMemrefStore final : OpConversionPattern<memref::StoreOp> {
Value extendedInput = rewriter.create<arith::ExtUIOp>(loc, dstIntegerType,
adaptor.getValue());
- // Special case 0-rank memref stores. We compute the mask at compile time.
+ // Special case 0-rank memref stores. No need for masking.
if (convertedType.getRank() == 0) {
- // Write src bits to destination
rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::assign,
extendedInput, adaptor.getMemref(),
ValueRange{});