[Mlir-commits] [mlir] [mlir] Add narrow type emulation conversions (PR #72181)
llvmlistbot at llvm.org
Mon Nov 13 16:50:26 PST 2023
llvmbot wrote:
@llvm/pr-subscribers-mlir-memref
Author: None (Max191)
Changes:
Adds narrow type emulation support for:
- `memref.alloca`
- `memref.store`
- `memref.reinterpret_cast`
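
As a quick reference for what the conversion computes, here is a minimal standalone C++ sketch (not part of the patch; `emulatedSize` is a hypothetical helper) of the size math used when flattening a narrow-type memref into a wider backing allocation:

```cpp
#include <cassert>
#include <cstdint>
#include <iostream>

// Number of dstBits-wide elements needed to back numElements values of
// srcBits each, i.e. ceil(numElements * srcBits / dstBits).
int64_t emulatedSize(int64_t numElements, int64_t srcBits, int64_t dstBits) {
  assert(dstBits % srcBits == 0 && "only dstBits % srcBits == 0 supported");
  return (numElements * srcBits + dstBits - 1) / dstBits;
}

int main() {
  std::cout << emulatedSize(5, 4, 8) << "\n";  // memref<5xi4> -> memref<3xi8>
  std::cout << emulatedSize(5, 4, 32) << "\n"; // memref<5xi4> -> memref<1xi32>
}
```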
---
Patch is 30.02 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/72181.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp (+218-40)
- (modified) mlir/test/Dialect/MemRef/emulate-narrow-type.mlir (+228)
``````````diff
diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
index 9f58e9055acadbb..078df55e351db96 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
@@ -17,6 +17,7 @@
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/OpDefinition.h"
#include "mlir/Support/MathExtras.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/FormatVariadic.h"
@@ -35,36 +36,98 @@ using namespace mlir;
/// Return the bit offset of the value at position `srcIdx`. For example, if
/// `sourceBits` equals 4 and `targetBits` equals 8, the x-th element is
/// located at (x % 2) * 4. Because there are two elements in one i8, and one
-/// element has 4 bits.
+/// element has 4 bits. If `rightOffset` is true, return the offset from the
+/// right side of the `targetBits` container instead of the left side.
static Value getOffsetForBitwidth(Location loc, OpFoldResult srcIdx,
int sourceBits, int targetBits,
- OpBuilder &builder) {
+ OpBuilder &builder,
+ bool rightOffset = false) {
assert(targetBits % sourceBits == 0);
AffineExpr s0;
bindSymbols(builder.getContext(), s0);
int scaleFactor = targetBits / sourceBits;
- OpFoldResult offsetVal = affine::makeComposedFoldedAffineApply(
- builder, loc, (s0 % scaleFactor) * sourceBits, {srcIdx});
+ AffineExpr offsetExpr =
+ rightOffset ? (scaleFactor - 1 - s0 % scaleFactor) * sourceBits
+ : (s0 % scaleFactor) * sourceBits;
+ OpFoldResult offsetVal =
+ affine::makeComposedFoldedAffineApply(builder, loc, offsetExpr, {srcIdx});
Value bitOffset = getValueOrCreateConstantIndexOp(builder, loc, offsetVal);
IntegerType dstType = builder.getIntegerType(targetBits);
return builder.create<arith::IndexCastOp>(loc, dstType, bitOffset);
}
+/// When writing a subbyte size, the write must happen atomically in case
+/// another write occurs on the same byte at the same time. To perform the
+/// write, the target bits within the `dstBits`-wide element at
+/// `linearizedIndices` must first be cleared. This function returns the
+/// appropriate mask for clearing these bits.
+static Value getAtomicWriteMask(Location loc, OpFoldResult linearizedIndices,
+ int64_t srcBits, int64_t dstBits,
+ Value bitwidthOffset, OpBuilder &builder) {
+ auto dstIntegerType = builder.getIntegerType(dstBits);
+ auto maskRightAlignedAttr =
+ builder.getIntegerAttr(dstIntegerType, (1 << srcBits) - 1);
+ Value maskRightAligned =
+ builder
+ .create<arith::ConstantOp>(loc, dstIntegerType, maskRightAlignedAttr)
+ .getResult();
+ Value writeMaskInverse =
+ builder.create<arith::ShLIOp>(loc, maskRightAligned, bitwidthOffset);
+ auto flipValAttr = builder.getIntegerAttr(dstIntegerType, -1);
+ Value flipVal =
+ builder.create<arith::ConstantOp>(loc, dstIntegerType, flipValAttr)
+ .getResult();
+ return builder.create<arith::XOrIOp>(loc, writeMaskInverse, flipVal);
+}
+
+/// Returns the scaled linearized index based on the `srcBits` and `dstBits`
+/// sizes. The input `linearizedIndex` has the granularity of `srcBits`, and
+/// the returned index has the granularity of `dstBits`.
+static Value getIndicesForLoadOrStore(OpBuilder &builder, Location loc,
+ OpFoldResult linearizedIndex,
+ int64_t srcBits, int64_t dstBits) {
+ AffineExpr s0;
+ bindSymbols(builder.getContext(), s0);
+ int64_t scaler = dstBits / srcBits;
+ OpFoldResult scaledLinearizedIndices = affine::makeComposedFoldedAffineApply(
+ builder, loc, s0.floorDiv(scaler), {linearizedIndex});
+ return getValueOrCreateConstantIndexOp(builder, loc, scaledLinearizedIndices);
+}
+
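+/// Linearize `indices` into a single `srcBits`-granularity index into the
+/// flattened memref, using the strided metadata of `memref`.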
+static OpFoldResult
+getLinearizedSrcIndices(OpBuilder &builder, Location loc, int64_t srcBits,
+ const SmallVector<OpFoldResult> &indices,
+ Value memref) {
+ auto stridedMetadata =
+ builder.create<memref::ExtractStridedMetadataOp>(loc, memref);
+ OpFoldResult linearizedIndices;
+ std::tie(std::ignore, linearizedIndices) =
+ memref::getLinearizedMemRefOffsetAndSize(
+ builder, loc, srcBits, srcBits,
+ stridedMetadata.getConstifiedMixedOffset(),
+ stridedMetadata.getConstifiedMixedSizes(),
+ stridedMetadata.getConstifiedMixedStrides(), indices);
+ return linearizedIndices;
+}
+
namespace {
//===----------------------------------------------------------------------===//
// ConvertMemRefAlloc
//===----------------------------------------------------------------------===//
-struct ConvertMemRefAlloc final : OpConversionPattern<memref::AllocOp> {
- using OpConversionPattern::OpConversionPattern;
+template <typename OpTy>
+struct ConvertMemRefAlloc final : OpConversionPattern<OpTy> {
+ using OpConversionPattern<OpTy>::OpConversionPattern;
LogicalResult
- matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor,
+ matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto currentType = op.getMemref().getType().cast<MemRefType>();
- auto newResultType =
- getTypeConverter()->convertType(op.getType()).dyn_cast<MemRefType>();
+ static_assert(std::is_same<OpTy, memref::AllocOp>() ||
+ std::is_same<OpTy, memref::AllocaOp>(),
+ "expected only memref::AllocOp or memref::AllocaOp");
+ auto currentType = cast<MemRefType>(op.getMemref().getType());
+ auto newResultType = dyn_cast<MemRefType>(
+ this->getTypeConverter()->convertType(op.getType()));
if (!newResultType) {
return rewriter.notifyMatchFailure(
op->getLoc(),
@@ -73,9 +136,9 @@ struct ConvertMemRefAlloc final : OpConversionPattern<memref::AllocOp> {
// Special case zero-rank memrefs.
if (currentType.getRank() == 0) {
- rewriter.replaceOpWithNewOp<memref::AllocOp>(
- op, newResultType, ValueRange{}, adaptor.getSymbolOperands(),
- adaptor.getAlignmentAttr());
+ rewriter.replaceOpWithNewOp<OpTy>(op, newResultType, ValueRange{},
+ adaptor.getSymbolOperands(),
+ adaptor.getAlignmentAttr());
return success();
}
@@ -97,9 +160,9 @@ struct ConvertMemRefAlloc final : OpConversionPattern<memref::AllocOp> {
rewriter, loc, linearizedMemRefInfo.linearizedSize));
}
- rewriter.replaceOpWithNewOp<memref::AllocOp>(
- op, newResultType, dynamicLinearizedSize, adaptor.getSymbolOperands(),
- adaptor.getAlignmentAttr());
+ rewriter.replaceOpWithNewOp<OpTy>(op, newResultType, dynamicLinearizedSize,
+ adaptor.getSymbolOperands(),
+ adaptor.getAlignmentAttr());
return success();
}
};
@@ -155,32 +218,15 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
bitsLoad = rewriter.create<memref::LoadOp>(loc, adaptor.getMemref(),
ValueRange{});
} else {
- SmallVector<OpFoldResult> indices =
- getAsOpFoldResult(adaptor.getIndices());
-
- auto stridedMetadata = rewriter.create<memref::ExtractStridedMetadataOp>(
- loc, op.getMemRef());
-
// Linearize the indices of the original load instruction. Do not account
// for the scaling yet. This will be accounted for later.
- OpFoldResult linearizedIndices;
- std::tie(std::ignore, linearizedIndices) =
- memref::getLinearizedMemRefOffsetAndSize(
- rewriter, loc, srcBits, srcBits,
- stridedMetadata.getConstifiedMixedOffset(),
- stridedMetadata.getConstifiedMixedSizes(),
- stridedMetadata.getConstifiedMixedStrides(), indices);
-
- AffineExpr s0;
- bindSymbols(rewriter.getContext(), s0);
- int64_t scaler = dstBits / srcBits;
- OpFoldResult scaledLinearizedIndices =
- affine::makeComposedFoldedAffineApply(
- rewriter, loc, s0.floorDiv(scaler), {linearizedIndices});
+ OpFoldResult linearizedIndices = getLinearizedSrcIndices(
+ rewriter, loc, srcBits, adaptor.getIndices(), op.getMemRef());
+
Value newLoad = rewriter.create<memref::LoadOp>(
loc, adaptor.getMemref(),
- getValueOrCreateConstantIndexOp(rewriter, loc,
- scaledLinearizedIndices));
+ getIndicesForLoadOrStore(rewriter, loc, linearizedIndices, srcBits,
+ dstBits));
// Get the offset and shift the bits to the rightmost.
// Note, currently only the big-endian is supported.
@@ -211,6 +257,136 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
}
};
+//===----------------------------------------------------------------------===//
+// ConvertMemRefReinterpretCast
+//===----------------------------------------------------------------------===//
+
+/// Converts memref.reinterpret_cast to use the emulated (wider) element
+/// type. Currently only rank-0 results with static offsets are supported.
+struct ConvertMemRefReinterpretCast final
+ : OpConversionPattern<memref::ReinterpretCastOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ MemRefType newTy =
+ dyn_cast<MemRefType>(getTypeConverter()->convertType(op.getType()));
+ if (!newTy) {
+ return rewriter.notifyMatchFailure(
+ op->getLoc(),
+ llvm::formatv("failed to convert memref type: {0}", op.getType()));
+ }
+
+ auto convertedElementType = newTy.getElementType();
+ auto oldElementType = op.getType().getElementType();
+ int srcBits = oldElementType.getIntOrFloatBitWidth();
+ int dstBits = convertedElementType.getIntOrFloatBitWidth();
+ if (dstBits % srcBits != 0) {
+ return rewriter.notifyMatchFailure(
+ op, "only dstBits % srcBits == 0 supported");
+ }
+
+    // Only support rank-0 results, for which only the offset matters.
+    if (op.getType().getRank() != 0) {
+      return rewriter.notifyMatchFailure(
+          op->getLoc(), "reinterpret_cast with rank > 0 is not supported");
+    }
+
+ int64_t offset = op.getStaticOffset(0);
+ // Only support static sizes and offsets.
+ if (offset == ShapedType::kDynamic) {
+ return rewriter.notifyMatchFailure(
+          op->getLoc(),
+          "reinterpret_cast with dynamic offset is not supported");
+ }
+
+ int elementsPerByte = dstBits / srcBits;
+ if (offset % elementsPerByte != 0) {
+ return rewriter.notifyMatchFailure(
+ op->getLoc(),
+        "reinterpret_cast with offset not a multiple of elementsPerByte is "
+        "not supported");
+ }
+
+ offset = offset / elementsPerByte;
+
+ rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
+ op, newTy, *adaptor.getODSOperands(0).begin(), offset,
+ SmallVector<int64_t>{}, op.getStaticStrides());
+ return success();
+ }
+};
+
+//===----------------------------------------------------------------------===//
+// ConvertMemRefStore
+//===----------------------------------------------------------------------===//
+
+struct ConvertMemRefStore final : OpConversionPattern<memref::StoreOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+    auto convertedType = cast<MemRefType>(adaptor.getMemref().getType());
+ auto convertedElementType = convertedType.getElementType();
+ auto oldElementType = op.getMemRefType().getElementType();
+ int srcBits = oldElementType.getIntOrFloatBitWidth();
+ int dstBits = convertedElementType.getIntOrFloatBitWidth();
+ auto dstIntegerType = rewriter.getIntegerType(dstBits);
+ if (dstBits % srcBits != 0) {
+ return rewriter.notifyMatchFailure(
+ op, "only dstBits % srcBits == 0 supported");
+ }
+
+ Location loc = op.getLoc();
+    Value extendedInput = rewriter.create<arith::ExtUIOp>(
+        loc, dstIntegerType, adaptor.getValue());
+
+ // Special case 0-rank memref stores. We can compute the mask at compile
+ // time.
+ if (convertedType.getRank() == 0) {
+      // Shift the extended value to be left aligned.
+      auto shiftValAttr =
+          rewriter.getIntegerAttr(dstIntegerType, dstBits - srcBits);
+      Value shiftVal =
+          rewriter.create<arith::ConstantOp>(loc, dstIntegerType, shiftValAttr)
+              .getResult();
+      Value alignedVal =
+          rewriter.create<arith::ShLIOp>(loc, extendedInput, shiftVal)
+              .getResult();
+      // Create a mask that clears the destination bits (the top srcBits).
+      auto writeMaskValAttr = rewriter.getIntegerAttr(
+          dstIntegerType, (1 << (dstBits - srcBits)) - 1);
+      Value writeMask =
+          rewriter
+              .create<arith::ConstantOp>(loc, dstIntegerType, writeMaskValAttr)
+              .getResult();
+
+      // Clear the destination bits.
+      rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::andi,
+                                           writeMask, adaptor.getMemref(),
+                                           ValueRange{});
+      // Write the source bits to the destination.
+      rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::ori,
+                                           alignedVal, adaptor.getMemref(),
+                                           ValueRange{});
+ rewriter.eraseOp(op);
+ return success();
+ }
+
+ OpFoldResult linearizedIndices = getLinearizedSrcIndices(
+ rewriter, loc, srcBits, adaptor.getIndices(), op.getMemRef());
+ Value storeIndices = getIndicesForLoadOrStore(
+ rewriter, loc, linearizedIndices, srcBits, dstBits);
+ Value bitwidthOffset = getOffsetForBitwidth(loc, linearizedIndices, srcBits,
+ dstBits, rewriter, true);
+ Value writeMask = getAtomicWriteMask(loc, linearizedIndices, srcBits,
+ dstBits, bitwidthOffset, rewriter);
+    // Shift the value to align with its destination bit position.
+ Value alignedVal =
+ rewriter.create<arith::ShLIOp>(loc, extendedInput, bitwidthOffset)
+ .getResult();
+
+ // Clear destination bits
+ rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::andi,
+ writeMask, adaptor.getMemref(),
+ storeIndices);
+    // Write the source bits to the destination.
+ rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::ori,
+ alignedVal, adaptor.getMemref(),
+ storeIndices);
+
+ rewriter.eraseOp(op);
+ return success();
+ }
+};
+
//===----------------------------------------------------------------------===//
// ConvertMemRefSubview
//===----------------------------------------------------------------------===//
@@ -291,8 +467,10 @@ void memref::populateMemRefNarrowTypeEmulationPatterns(
RewritePatternSet &patterns) {
// Populate `memref.*` conversion patterns.
- patterns.add<ConvertMemRefAlloc, ConvertMemRefLoad,
- ConvertMemRefAssumeAlignment, ConvertMemRefSubview>(
+ patterns.add<ConvertMemRefAlloc<memref::AllocOp>,
+ ConvertMemRefAlloc<memref::AllocaOp>, ConvertMemRefLoad,
+ ConvertMemRefAssumeAlignment, ConvertMemRefSubview,
+               ConvertMemRefStore, ConvertMemRefReinterpretCast>(
typeConverter, patterns.getContext());
memref::populateResolveExtractStridedMetadataPatterns(patterns);
}
diff --git a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
index 6ed97f05aa7cff2..05ec5761c8fe024 100644
--- a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
+++ b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
@@ -174,3 +174,231 @@ func.func @memref_strided_i4(%idx : index) -> i4 {
// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<16xi32>
// CHECK32: %[[SUBVIEW:.+]] = memref.subview %[[ALLOC]][4] [4] [1] : memref<16xi32> to memref<4xi32, strided<[1], offset: 4>>
// CHECK32: %[[LOAD:.+]] = memref.load %[[SUBVIEW]]
+
+// -----
+
+func.func @memref_alloca_load_i4(%arg0: index) -> i4 {
+ %0 = memref.alloca() : memref<5xi4>
+ %1 = memref.load %0[%arg0] : memref<5xi4>
+ return %1 : i4
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 floordiv 2)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 2) * 8)>
+// CHECK: func @memref_alloca_load_i4(
+// CHECK-SAME: %[[ARG0:.+]]: index
+// CHECK: %[[ALLOCA:.+]] = memref.alloca() : memref<3xi8>
+// CHECK: %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
+// CHECK: %[[LOADVAL:.+]] = memref.load %[[ALLOCA]][%[[INDEX]]]
+// CHECK: %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]]]
+// CHECK: %[[CAST:.+]] = arith.index_cast %[[BITOFFSET]] : index to i8
+// CHECK: %[[SHIFTRT:.+]] = arith.shrsi %[[LOADVAL]], %[[CAST]]
+// CHECK: %[[TRUNC:.+]] = arith.trunci %[[SHIFTRT]] : i8 to i4
+// CHECK: return %[[TRUNC]]
+
+// CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 floordiv 8)>
+// CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 8) * 32)>
+// CHECK32: func @memref_alloca_load_i4(
+// CHECK32-SAME: %[[ARG0:.+]]: index
+// CHECK32: %[[ALLOCA:.+]] = memref.alloca() : memref<1xi32>
+// CHECK32: %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
+// CHECK32: %[[LOADVAL:.+]] = memref.load %[[ALLOCA]][%[[INDEX]]]
+// CHECK32: %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]]]
+// CHECK32: %[[CAST:.+]] = arith.index_cast %[[BITOFFSET]] : index to i32
+// CHECK32: %[[SHIFTRT:.+]] = arith.shrsi %[[LOADVAL]], %[[CAST]]
+// CHECK32: %[[TRUNC:.+]] = arith.trunci %[[SHIFTRT]] : i32 to i4
+// CHECK32: return %[[TRUNC]]
+
+// -----
+
+func.func @memref_store_i4(%arg0: index, %arg1: i4) -> () {
+ %0 = memref.alloc() : memref<5xi4>
+ memref.store %arg1, %0[%arg0] : memref<5xi4>
+ return
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 floordiv 2)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * -4 + (s0 floordiv 2) * 8 + 4)>
+// CHECK: func @memref_store_i4(
+// CHECK-SAME: %[[ARG0:.+]]: index, %[[ARG1:.+]]: i4
+// CHECK-DAG: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
+// CHECK-DAG: %[[EXTUI:.+]] = arith.extui %[[ARG1]] : i4 to i8
+// CHECK-DAG: %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
+// CHECK-DAG: %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]]]
+// CHECK-DAG: %[[BITOFFSET_I8:.+]] = arith.index_cast %[[BITOFFSET]] : index to i8
+// CHECK-DAG: %[[MASK_BASE:.+]] = arith.constant 15 : i8
+// CHECK-DAG: %[[MASK_SHIFTED:.+]] = arith.shli %[[MASK_BASE]], %[[BITOFFSET_I8]] : i8
+// CHECK-DAG: %[[CST_NEG_ONE:.+]] = arith.constant -1 : i8
+// CHECK-DAG: %[[MASK:.+]] = arith.xori %[[MASK_SHIFTED]], %[[CST_NEG_ONE]] : i8
+// CHECK-DAG: %[[SHIFTED_VAL:.+]] = arith.shli %[[EXTUI]], %[[BITOFFSET_I8]] : i8
+// CHECK: %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %[[ALLOC]][%[[INDEX]]] : (i8, memref<3xi8>) -> i8
+// CHECK: %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[SHIFTED_VAL]], %[[ALLOC]][%[[INDEX]]] : (i8, memref<3xi8>) -> i8
+// CHECK: return
+
+// CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 floordiv 8)>
+// CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * -4 + (s0 floordiv 8) * 32 + 28)>
+// CHECK32: func @memref_store_i4(
+// CHECK32-SAME: %[[ARG0:.+]]: index, %[[ARG1:.+]]: i4
+// CHECK32-DAG: %[[ALLOC:.+]] = memref.alloc() : memref<1xi32>
+// CHECK32-DAG: %[[EXTUI:.+]] = arith.extui %[[ARG1]] : i4 to i32
+// CHECK32-DAG: %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
+// CHECK32-DAG: %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]]]
+// CHECK32-DAG: %[[BITOFFSET_I32:.+]] = arith.index_cast %[[BITOFFSET]] : index to i32
+// CHECK32-DAG: %[[MASK_BASE:.+]] = arith.constant 15 : i32
+// CHECK32-DAG: %[[MASK_SHIFTED:.+]] = arith.shli %[[MASK_BASE]], ...
[truncated]
``````````
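
To make the store-side bit manipulation easier to follow, here is a small self-contained C++ sketch (hypothetical names, not part of the patch) of the arithmetic the store pattern emits: the scaled byte index from `getIndicesForLoadOrStore`, the right-aligned bit offset from `getOffsetForBitwidth`, and the clear-then-set masking performed by the two `memref.atomic_rmw` ops:

```cpp
#include <cstdint>
#include <iostream>

// Store a 4-bit value into a packed i8 buffer the way the emitted IR does:
// clear the destination nibble with an andi mask, then ori in the shifted
// value. Offsets follow the patch's rightOffset formula, so element 0 lands
// in the high nibble of byte 0.
void emulatedI4Store(uint8_t *buf, int64_t idx, uint8_t val) {
  const int64_t srcBits = 4, dstBits = 8;
  const int64_t scale = dstBits / srcBits;                 // 2 elements/byte
  int64_t byteIdx = idx / scale;                           // scaled index
  int64_t bitOffset = (scale - 1 - idx % scale) * srcBits; // rightOffset form
  uint8_t maskRightAligned = (1u << srcBits) - 1;          // 0x0F
  uint8_t writeMask = ~(uint8_t)(maskRightAligned << bitOffset);
  uint8_t aligned = (uint8_t)((val & maskRightAligned) << bitOffset);
  buf[byteIdx] &= writeMask; // clear destination bits (atomic_rmw andi)
  buf[byteIdx] |= aligned;   // write source bits (atomic_rmw ori)
}

int main() {
  uint8_t buf[3] = {0, 0, 0};   // backing storage for memref<5xi4>
  emulatedI4Store(buf, 0, 0xA); // high nibble of byte 0
  emulatedI4Store(buf, 1, 0x5); // low nibble of byte 0
  std::cout << std::hex << int(buf[0]) << "\n"; // prints a5
}
```

The two read-modify-write updates are shown as plain `&=`/`|=` here; in the generated IR they are atomic, so concurrent stores to the two nibbles of the same byte do not race.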
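The load side, per the existing `ConvertMemRefLoad` pattern and its CHECK lines above, loads the containing byte, shifts the element to the rightmost bits, and truncates. A sketch under the same assumptions (hypothetical names, following the load pattern's own offset formula):

```cpp
#include <cstdint>
#include <iostream>

// Load a 4-bit value from a packed i8 buffer the way the emitted IR does:
// memref.load the containing byte, arith.shrsi by the bit offset, then
// arith.trunci to i4 (modeled here by masking to the low nibble).
int8_t emulatedI4Load(const uint8_t *buf, int64_t idx) {
  const int64_t srcBits = 4, dstBits = 8;
  const int64_t scale = dstBits / srcBits;
  int64_t byteIdx = idx / scale;                  // map: s0 floordiv 2
  int64_t bitOffset = (idx % scale) * srcBits;    // map: (s0 mod 2) * 4
  int8_t loaded = (int8_t)buf[byteIdx];
  int8_t shifted = (int8_t)(loaded >> bitOffset); // arithmetic shift (shrsi)
  return (int8_t)(shifted & 0xF);                 // model the i8 -> i4 trunc
}

int main() {
  uint8_t buf[3] = {0xA5, 0x00, 0x00};
  std::cout << std::hex << int(emulatedI4Load(buf, 0)) << "\n"; // prints 5
  std::cout << std::hex << int(emulatedI4Load(buf, 1)) << "\n"; // prints a
}
```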
https://github.com/llvm/llvm-project/pull/72181