[Mlir-commits] [mlir] [mlir] Add subbyte emulation support for `memref.store`. (PR #72004)
llvmlistbot at llvm.org
Fri Nov 10 16:12:56 PST 2023
https://github.com/Max191 created https://github.com/llvm/llvm-project/pull/72004
None
From 5e4abca5feb9c74b9105d9d907064992bb0d2fc5 Mon Sep 17 00:00:00 2001
From: Max Dawkins <max.dawkins at gmail.com>
Date: Fri, 10 Nov 2023 17:06:43 -0500
Subject: [PATCH] [mlir] Add subbyte emulation support for `memref.store`.
---
.../MemRef/Transforms/EmulateNarrowType.cpp | 102 +++++++++++++++++-
1 file changed, 101 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
index 9f58e9055acadbb..44070e789d9ed0e 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
@@ -50,6 +50,20 @@ static Value getOffsetForBitwidth(Location loc, OpFoldResult srcIdx,
return builder.create<arith::IndexCastOp>(loc, dstType, bitOffset);
}
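+/// Computes the offset (in bits) of the slot for element `srcIdx` within a
+/// `targetBits`-wide backing word, packing elements starting from the most
+/// significant bits. E.g., for sourceBits = 4 and targetBits = 8, an even
+/// srcIdx maps to offset 4 (the high nibble) and an odd srcIdx to offset 0
+/// (the low nibble).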
+static Value getRightOffsetForBitwidth(Location loc, OpFoldResult srcIdx,
+ int sourceBits, int targetBits,
+ OpBuilder &builder) {
+  assert(targetBits % sourceBits == 0 &&
+         "targetBits must be a multiple of sourceBits");
+ AffineExpr s0;
+ bindSymbols(builder.getContext(), s0);
+ int scaleFactor = targetBits / sourceBits;
+ OpFoldResult offsetVal = affine::makeComposedFoldedAffineApply(
+ builder, loc, (scaleFactor - 1 - s0 % scaleFactor) * sourceBits, {srcIdx});
+ Value bitOffset = getValueOrCreateConstantIndexOp(builder, loc, offsetVal);
+ IntegerType dstType = builder.getIntegerType(targetBits);
+ return builder.create<arith::IndexCastOp>(loc, dstType, bitOffset);
+}
+
namespace {
//===----------------------------------------------------------------------===//
@@ -211,6 +225,92 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
}
};
+//===----------------------------------------------------------------------===//
+// ConvertMemrefStore
+//===----------------------------------------------------------------------===//
+
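+/// Emulates a `memref.store` of a sub-byte element type (e.g. i4) on a wider
+/// backing type as a pair of `memref.atomic_rmw` ops: an `andi` clears the
+/// destination bits, then an `ori` writes the shifted source bits.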
+struct ConvertMemrefStore final : OpConversionPattern<memref::StoreOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto convertedType = adaptor.getMemref().getType().cast<MemRefType>();
+ auto convertedElementType = convertedType.getElementType();
+ auto oldElementType = op.getMemRefType().getElementType();
+ int srcBits = oldElementType.getIntOrFloatBitWidth();
+ int dstBits = convertedElementType.getIntOrFloatBitWidth();
+ if (dstBits % srcBits != 0) {
+ return rewriter.notifyMatchFailure(
+ op, "only dstBits % srcBits == 0 supported");
+ }
+
+    Location loc = op.getLoc();
+    auto dstIntegerType = rewriter.getIntegerType(dstBits);
+    // Zero-extend the value to store to the backing element type.
+    Value extendedInput = rewriter.create<arith::ExtUIOp>(
+        loc, dstIntegerType, adaptor.getValue());
+    // Special case 0-rank memref stores.
+    if (convertedType.getRank() == 0) {
+      // Shift the extended value to be left aligned.
+      auto shiftValAttr =
+          rewriter.getIntegerAttr(dstIntegerType, dstBits - srcBits);
+      Value shiftVal = rewriter.create<arith::ConstantOp>(loc, dstIntegerType,
+                                                          shiftValAttr);
+      Value alignedVal =
+          rewriter.create<arith::ShLIOp>(loc, extendedInput, shiftVal);
+      // Create a mask that clears the destination bits: the stored value
+      // occupies the top `srcBits` bits, so keep only the bits below them.
+      auto writeMaskValAttr = rewriter.getIntegerAttr(
+          dstIntegerType, (1 << (dstBits - srcBits)) - 1);
+      Value writeMask = rewriter.create<arith::ConstantOp>(loc, dstIntegerType,
+                                                           writeMaskValAttr);
+
+      // Clear the destination bits.
+      rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::andi,
+                                           writeMask, adaptor.getMemref(),
+                                           ValueRange{});
+      // Write the source bits into the destination.
+      rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::ori,
+                                           alignedVal, adaptor.getMemref(),
+                                           ValueRange{});
+ } else {
+ SmallVector<OpFoldResult> indices =
+ getAsOpFoldResult(adaptor.getIndices());
+
+ auto stridedMetadata = rewriter.create<memref::ExtractStridedMetadataOp>(
+ loc, op.getMemRef());
+
+      // Linearize the indices of the original store op. Do not account for
+      // the scaling yet; that is applied below.
+ OpFoldResult linearizedIndices;
+ std::tie(std::ignore, linearizedIndices) =
+ memref::getLinearizedMemRefOffsetAndSize(
+ rewriter, loc, srcBits, srcBits,
+ stridedMetadata.getConstifiedMixedOffset(),
+ stridedMetadata.getConstifiedMixedSizes(),
+ stridedMetadata.getConstifiedMixedStrides(), indices);
+
+ AffineExpr s0;
+ bindSymbols(rewriter.getContext(), s0);
+ int64_t scaler = dstBits / srcBits;
+ OpFoldResult scaledLinearizedIndices =
+ affine::makeComposedFoldedAffineApply(
+ rewriter, loc, s0.floorDiv(scaler), {linearizedIndices});
+      Value storeIndices = getValueOrCreateConstantIndexOp(
+          rewriter, loc, scaledLinearizedIndices);
+
+      // Compute the bit offset of the element within the backing word, then
+      // build a mask that clears the destination bits.
+      Value bitwidthOffset = getRightOffsetForBitwidth(
+          loc, linearizedIndices, srcBits, dstBits, rewriter);
+      auto maskRightAlignedAttr =
+          rewriter.getIntegerAttr(dstIntegerType, (1 << srcBits) - 1);
+      Value maskRightAligned = rewriter.create<arith::ConstantOp>(
+          loc, dstIntegerType, maskRightAlignedAttr);
+      Value writeMaskInverse = rewriter.create<arith::ShLIOp>(
+          loc, maskRightAligned, bitwidthOffset);
+      // Flip the bits of `writeMaskInverse` (XOR with all ones) to get a mask
+      // that clears only the destination bits.
+      auto flipValAttr = rewriter.getIntegerAttr(dstIntegerType, -1);
+      Value flipVal = rewriter.create<arith::ConstantOp>(loc, dstIntegerType,
+                                                         flipValAttr);
+      Value writeMask =
+          rewriter.create<arith::XOrIOp>(loc, writeMaskInverse, flipVal);
+      // Align the value to write with the destination bits.
+      Value alignedVal =
+          rewriter.create<arith::ShLIOp>(loc, extendedInput, bitwidthOffset);
+
+      // Clear the destination bits.
+      rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::andi,
+                                           writeMask, adaptor.getMemref(),
+                                           storeIndices);
+      // Write the source bits into the destination.
+      rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::ori,
+                                           alignedVal, adaptor.getMemref(),
+                                           storeIndices);
+ }
+
+ rewriter.eraseOp(op);
+ return success();
+ }
+};
+
//===----------------------------------------------------------------------===//
// ConvertMemRefSubview
//===----------------------------------------------------------------------===//
@@ -292,7 +392,7 @@ void memref::populateMemRefNarrowTypeEmulationPatterns(
// Populate `memref.*` conversion patterns.
patterns.add<ConvertMemRefAlloc, ConvertMemRefLoad,
- ConvertMemRefAssumeAlignment, ConvertMemRefSubview>(
+               ConvertMemRefAssumeAlignment, ConvertMemRefSubview,
+               ConvertMemrefStore>(
typeConverter, patterns.getContext());
memref::populateResolveExtractStridedMetadataPatterns(patterns);
}
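For reference, here is a plain-C++ sketch of the read-modify-write sequence
the pattern emits, specialized to i4 elements packed into an i8 backing
buffer (the helper name and the std::atomic buffer are illustrative only,
not part of the patch):

    #include <atomic>
    #include <cstdint>

    // Mirrors the ops built by ConvertMemrefStore for srcBits = 4,
    // dstBits = 8 (scaler = 2).
    void emulatedStoreI4(std::atomic<uint8_t> *buffer, int64_t linearIdx,
                         uint8_t value /* low 4 bits valid */) {
      int64_t wordIdx = linearIdx / 2;          // s0.floorDiv(scaler)
      int offset = (2 - 1 - linearIdx % 2) * 4; // getRightOffsetForBitwidth
      uint8_t writeMaskInverse = uint8_t(0xF << offset);
      uint8_t writeMask = uint8_t(~writeMaskInverse); // XOrI with all ones
      uint8_t alignedVal = uint8_t((value & 0xF) << offset);
      // Clear the destination bits, then OR in the shifted value. These map
      // to the two memref.atomic_rmw ops (andi, then ori) above.
      buffer[wordIdx].fetch_and(writeMask);
      buffer[wordIdx].fetch_or(alignedVal);
    }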