[Mlir-commits] [mlir] [mlir] Add subbyte emulation support for `memref.store`. (PR #72004)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Nov 10 16:12:56 PST 2023


https://github.com/Max191 created https://github.com/llvm/llvm-project/pull/72004

None

From 5e4abca5feb9c74b9105d9d907064992bb0d2fc5 Mon Sep 17 00:00:00 2001
From: Max Dawkins <max.dawkins at gmail.com>
Date: Fri, 10 Nov 2023 17:06:43 -0500
Subject: [PATCH] [mlir] Add subbyte emulation support for `memref.store`.

---
 .../MemRef/Transforms/EmulateNarrowType.cpp   | 102 +++++++++++++++++-
 1 file changed, 101 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
index 9f58e9055acadbb..44070e789d9ed0e 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
@@ -50,6 +50,20 @@ static Value getOffsetForBitwidth(Location loc, OpFoldResult srcIdx,
   return builder.create<arith::IndexCastOp>(loc, dstType, bitOffset);
 }
 
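+/// Returns the bit offset used to shift the element at `srcIdx` into a
+/// `targetBits`-wide word when elements are packed starting from the
+/// most-significant end. E.g. for i4 elements in an i8 word, index 0 maps to
+/// offset 4 (the high nibble) and index 1 to offset 0 (the low nibble).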
+static Value getRightOffsetForBitwidth(Location loc, OpFoldResult srcIdx,
+                                       int sourceBits, int targetBits,
+                                       OpBuilder &builder) {
+  assert(targetBits % sourceBits == 0);
+  AffineExpr s0;
+  bindSymbols(builder.getContext(), s0);
+  int scaleFactor = targetBits / sourceBits;
+  OpFoldResult offsetVal = affine::makeComposedFoldedAffineApply(
+      builder, loc, (scaleFactor - 1 - s0 % scaleFactor) * sourceBits, {srcIdx});
+  Value bitOffset = getValueOrCreateConstantIndexOp(builder, loc, offsetVal);
+  IntegerType dstType = builder.getIntegerType(targetBits);
+  return builder.create<arith::IndexCastOp>(loc, dstType, bitOffset);
+}
+
 namespace {
 
 //===----------------------------------------------------------------------===//
@@ -211,6 +225,92 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
   }
 };
 
+//===----------------------------------------------------------------------===//
+// ConvertMemrefStore
+//===----------------------------------------------------------------------===//
+
+struct ConvertMemrefStore final : OpConversionPattern<memref::StoreOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto convertedType = adaptor.getMemref().getType().cast<MemRefType>();
+    auto convertedElementType = convertedType.getElementType();
+    auto oldElementType = op.getMemRefType().getElementType();
+    int srcBits = oldElementType.getIntOrFloatBitWidth();
+    int dstBits = convertedElementType.getIntOrFloatBitWidth();
+    if (dstBits % srcBits != 0) {
+      return rewriter.notifyMatchFailure(
+          op, "only dstBits % srcBits == 0 supported");
+    }
+
+    Location loc = op.getLoc();
+    auto dstIntegerType = rewriter.getIntegerType(dstBits);
+    Value extendedInput = rewriter.create<arith::ExtUIOp>(
+        loc, dstIntegerType, adaptor.getValue());
+    // Special case 0-rank memref stores.
+    if (convertedType.getRank() == 0) {
+      // Shift the extended value to be left-aligned.
+      auto shiftValAttr =
+          rewriter.getIntegerAttr(dstIntegerType, dstBits - srcBits);
+      Value shiftVal = rewriter.create<arith::ConstantOp>(
+          loc, dstIntegerType, shiftValAttr);
+      Value alignedVal =
+          rewriter.create<arith::ShLIOp>(loc, extendedInput, shiftVal);
+      // Create a mask to clear the destination bits: the value occupies the
+      // high srcBits, so keep only the low (dstBits - srcBits) bits.
+      auto writeMaskValAttr = rewriter.getIntegerAttr(
+          dstIntegerType, (1 << (dstBits - srcBits)) - 1);
+      Value writeMask = rewriter.create<arith::ConstantOp>(
+          loc, dstIntegerType, writeMaskValAttr);
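+      // E.g. for i4 into i8 the mask is 0b00001111: the andi below clears the
+      // high nibble that `alignedVal` occupies.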
+
+      // Clear the destination bits.
+      rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::andi,
+                                           writeMask, adaptor.getMemref(),
+                                           ValueRange{});
+      // Write the source bits to the destination.
+      rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::ori,
+                                           alignedVal, adaptor.getMemref(),
+                                           ValueRange{});
+    } else {
+      SmallVector<OpFoldResult> indices =
+          getAsOpFoldResult(adaptor.getIndices());
+
+      auto stridedMetadata = rewriter.create<memref::ExtractStridedMetadataOp>(
+          loc, op.getMemRef());
+
+      // Linearize the indices of the original store op. Do not account for
+      // the scaling yet; it is applied below.
+      OpFoldResult linearizedIndices;
+      std::tie(std::ignore, linearizedIndices) =
+          memref::getLinearizedMemRefOffsetAndSize(
+              rewriter, loc, srcBits, srcBits,
+              stridedMetadata.getConstifiedMixedOffset(),
+              stridedMetadata.getConstifiedMixedSizes(),
+              stridedMetadata.getConstifiedMixedStrides(), indices);
+
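+      // Scale the linearized index down to the wider storage elements: for
+      // i4 elements packed into i8, sub-byte index 5 maps to byte index
+      // 5 floordiv 2 = 2.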
+      AffineExpr s0;
+      bindSymbols(rewriter.getContext(), s0);
+      int64_t scaler = dstBits / srcBits;
+      OpFoldResult scaledLinearizedIndices =
+          affine::makeComposedFoldedAffineApply(
+              rewriter, loc, s0.floorDiv(scaler), {linearizedIndices});
+      Value storeIndices = getValueOrCreateConstantIndexOp(
+          rewriter, loc, scaledLinearizedIndices);
+
+      // Create a mask to clear the destination bits.
+      Value bitwidthOffset = getRightOffsetForBitwidth(
+          loc, linearizedIndices, srcBits, dstBits, rewriter);
+      auto maskRightAlignedAttr =
+          rewriter.getIntegerAttr(dstIntegerType, (1 << srcBits) - 1);
+      Value maskRightAligned = rewriter.create<arith::ConstantOp>(
+          loc, dstIntegerType, maskRightAlignedAttr);
+      Value writeMaskInverse =
+          rewriter.create<arith::ShLIOp>(loc, maskRightAligned, bitwidthOffset);
+      // Bitwise NOT `writeMaskInverse` to get a mask that clears only the
+      // destination bits.
+      auto flipValAttr = rewriter.getIntegerAttr(dstIntegerType, -1);
+      Value flipVal = rewriter.create<arith::ConstantOp>(
+          loc, dstIntegerType, flipValAttr);
+      Value writeMask =
+          rewriter.create<arith::XOrIOp>(loc, writeMaskInverse, flipVal);
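+      // E.g. for i4 into i8 at an even index: `maskRightAligned` (0b00001111)
+      // is shifted to `writeMaskInverse` (0b11110000) and flipped to
+      // `writeMask` (0b00001111).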
+      // Align the value to write with the destination bits.
+      Value alignedVal =
+          rewriter.create<arith::ShLIOp>(loc, extendedInput, bitwidthOffset);
+
+      // Clear the destination bits.
+      rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::andi,
+                                           writeMask, adaptor.getMemref(),
+                                           storeIndices);
+      // Write the source bits to the destination.
+      rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::ori,
+                                           alignedVal, adaptor.getMemref(),
+                                           storeIndices);
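+      // The clear and the write are two separate atomic RMWs: stores to other
+      // sub-elements of the same byte interleave safely, but a concurrent
+      // reader may observe the intermediate cleared state.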
+    }
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // ConvertMemRefSubview
 //===----------------------------------------------------------------------===//
@@ -292,7 +392,7 @@ void memref::populateMemRefNarrowTypeEmulationPatterns(
 
   // Populate `memref.*` conversion patterns.
   patterns.add<ConvertMemRefAlloc, ConvertMemRefLoad,
-               ConvertMemRefAssumeAlignment, ConvertMemRefSubview>(
+               ConvertMemRefAssumeAlignment, ConvertMemRefSubview,
+               ConvertMemrefStore>(
       typeConverter, patterns.getContext());
   memref::populateResolveExtractStridedMetadataPatterns(patterns);
 }
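
For illustration, a minimal sketch of the rewrite in the ranked case, assuming
i4 elements emulated in i8 storage (the SSA names and memref shapes here are
hypothetical, not taken from the PR's tests):

  // Before: storing one i4 element.
  memref.store %v, %src[%i] : memref<8xi4>

  // After (abridged): the i4 memref is backed by memref<4xi8>; %byte is the
  // linearized index floordiv 2, %shift selects the sub-byte slot, and %mask
  // has the target slot's bits cleared.
  %ext = arith.extui %v : i4 to i8
  %val = arith.shli %ext, %shift : i8
  %o0 = memref.atomic_rmw andi %mask, %dst[%byte] : (i8, memref<4xi8>) -> i8
  %o1 = memref.atomic_rmw ori %val, %dst[%byte] : (i8, memref<4xi8>) -> i8

Atomics are used so that two threads storing to different sub-elements of the
same byte do not lose each other's updates.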



More information about the Mlir-commits mailing list