[Mlir-commits] [mlir] [mlir] Add subbyte emulation support for `memref.store`. (PR #72004)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Nov 10 16:24:06 PST 2023


github-actions[bot] wrote:

<!--LLVM CODE FORMAT COMMENT: {clang-format}-->


:warning: The C/C++ code formatter, clang-format, found issues in your code. :warning:

<details>
<summary>
You can test this locally with the following command:
</summary>

``````````bash
git-clang-format --diff 64f62de96609dc3ea9a8a914a9e9445b7f4d625d 5e4abca5feb9c74b9105d9d907064992bb0d2fc5 -- mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
``````````

</details>

<details>
<summary>
View the diff from clang-format here.
</summary>

``````````diff
diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
index 44070e789d9e..7bed0162c3b8 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
@@ -51,14 +51,15 @@ static Value getOffsetForBitwidth(Location loc, OpFoldResult srcIdx,
 }
 
 static Value getRightOffsetForBitwidth(Location loc, OpFoldResult srcIdx,
-                                  int sourceBits, int targetBits,
-                                  OpBuilder &builder) {
+                                       int sourceBits, int targetBits,
+                                       OpBuilder &builder) {
   assert(targetBits % sourceBits == 0);
   AffineExpr s0;
   bindSymbols(builder.getContext(), s0);
   int scaleFactor = targetBits / sourceBits;
   OpFoldResult offsetVal = affine::makeComposedFoldedAffineApply(
-      builder, loc, (scaleFactor - 1 - s0 % scaleFactor) * sourceBits, {srcIdx});
+      builder, loc, (scaleFactor - 1 - s0 % scaleFactor) * sourceBits,
+      {srcIdx});
   Value bitOffset = getValueOrCreateConstantIndexOp(builder, loc, offsetVal);
   IntegerType dstType = builder.getIntegerType(targetBits);
   return builder.create<arith::IndexCastOp>(loc, dstType, bitOffset);
@@ -248,20 +249,34 @@ struct ConvertMemrefStore final : OpConversionPattern<memref::StoreOp> {
     Location loc = op.getLoc();
     // Special case 0-rank memref loads.
     auto dstIntegerType = rewriter.getIntegerType(dstBits);
-    Value extendedInput = rewriter.create<arith::ExtUIOp>(loc, dstIntegerType, adaptor.getValue());
+    Value extendedInput = rewriter.create<arith::ExtUIOp>(loc, dstIntegerType,
+                                                          adaptor.getValue());
     if (convertedType.getRank() == 0) {
       // Shift extended value to be left aligned
-      auto shiftValAttr = rewriter.getIntegerAttr(dstIntegerType, dstBits - srcBits);
-      Value shiftVal = rewriter.create<arith::ConstantOp>(loc, dstIntegerType, shiftValAttr).getResult();
-      Value alignedVal = rewriter.create<arith::ShLIOp>(loc, extendedInput, shiftVal).getResult();
+      auto shiftValAttr =
+          rewriter.getIntegerAttr(dstIntegerType, dstBits - srcBits);
+      Value shiftVal =
+          rewriter.create<arith::ConstantOp>(loc, dstIntegerType, shiftValAttr)
+              .getResult();
+      Value alignedVal =
+          rewriter.create<arith::ShLIOp>(loc, extendedInput, shiftVal)
+              .getResult();
       // Create mask to clear destination bits
-      auto writeMaskValAttr = rewriter.getIntegerAttr(dstIntegerType, 1 << (dstBits - srcBits) - 1);
-      Value writeMask = rewriter.create<arith::ConstantOp>(loc, dstIntegerType, writeMaskValAttr).getResult();
+      auto writeMaskValAttr =
+          rewriter.getIntegerAttr(dstIntegerType, 1 << (dstBits - srcBits) - 1);
+      Value writeMask =
+          rewriter
+              .create<arith::ConstantOp>(loc, dstIntegerType, writeMaskValAttr)
+              .getResult();
 
       // Clear destination bits
-      rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::andi, writeMask, adaptor.getMemref(), ValueRange{});
+      rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::andi,
+                                           writeMask, adaptor.getMemref(),
+                                           ValueRange{});
       // Write srcs bits to destination
-      rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::ori, alignedVal, adaptor.getMemref(), ValueRange{});
+      rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::ori,
+                                           alignedVal, adaptor.getMemref(),
+                                           ValueRange{});
     } else {
       SmallVector<OpFoldResult> indices =
           getAsOpFoldResult(adaptor.getIndices());
@@ -285,25 +300,42 @@ struct ConvertMemrefStore final : OpConversionPattern<memref::StoreOp> {
       OpFoldResult scaledLinearizedIndices =
           affine::makeComposedFoldedAffineApply(
               rewriter, loc, s0.floorDiv(scaler), {linearizedIndices});
-      Value storeIndices = getValueOrCreateConstantIndexOp(rewriter, loc, scaledLinearizedIndices);
+      Value storeIndices = getValueOrCreateConstantIndexOp(
+          rewriter, loc, scaledLinearizedIndices);
 
       // Create mask to clear destination bits
-      Value bitwidthOffset = getRightOffsetForBitwidth(loc, linearizedIndices,
-                                                  srcBits, dstBits, rewriter);
-      auto maskRightAlignedAttr = rewriter.getIntegerAttr(dstIntegerType, (1 << srcBits) - 1);
-      Value maskRightAligned = rewriter.create<arith::ConstantOp>(loc, dstIntegerType, maskRightAlignedAttr).getResult();
-      Value writeMaskInverse = rewriter.create<arith::ShLIOp>(loc, maskRightAligned, bitwidthOffset);
-      // Perform logical NOT on the `writeMaskInverse` to get a mask that clears the destination bits
+      Value bitwidthOffset = getRightOffsetForBitwidth(
+          loc, linearizedIndices, srcBits, dstBits, rewriter);
+      auto maskRightAlignedAttr =
+          rewriter.getIntegerAttr(dstIntegerType, (1 << srcBits) - 1);
+      Value maskRightAligned =
+          rewriter
+              .create<arith::ConstantOp>(loc, dstIntegerType,
+                                         maskRightAlignedAttr)
+              .getResult();
+      Value writeMaskInverse =
+          rewriter.create<arith::ShLIOp>(loc, maskRightAligned, bitwidthOffset);
+      // Perform logical NOT on the `writeMaskInverse` to get a mask that clears
+      // the destination bits
       auto flipValAttr = rewriter.getIntegerAttr(dstIntegerType, -1);
-      Value flipVal = rewriter.create<arith::ConstantOp>(loc, dstIntegerType, flipValAttr).getResult();
-      Value writeMask = rewriter.create<arith::XOrIOp>(loc, writeMaskInverse, flipVal);
+      Value flipVal =
+          rewriter.create<arith::ConstantOp>(loc, dstIntegerType, flipValAttr)
+              .getResult();
+      Value writeMask =
+          rewriter.create<arith::XOrIOp>(loc, writeMaskInverse, flipVal);
       // Align the value to write with the destination bits
-      Value alignedVal = rewriter.create<arith::ShLIOp>(loc, extendedInput, bitwidthOffset).getResult();
+      Value alignedVal =
+          rewriter.create<arith::ShLIOp>(loc, extendedInput, bitwidthOffset)
+              .getResult();
 
       // Clear destination bits
-      rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::andi, writeMask, adaptor.getMemref(), storeIndices);
+      rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::andi,
+                                           writeMask, adaptor.getMemref(),
+                                           storeIndices);
       // Write srcs bits to destination
-      rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::ori, alignedVal, adaptor.getMemref(), storeIndices);
+      rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::ori,
+                                           alignedVal, adaptor.getMemref(),
+                                           storeIndices);
     }
 
     rewriter.eraseOp(op);
@@ -391,9 +423,10 @@ void memref::populateMemRefNarrowTypeEmulationPatterns(
     RewritePatternSet &patterns) {
 
   // Populate `memref.*` conversion patterns.
-  patterns.add<ConvertMemRefAlloc, ConvertMemRefLoad,
-               ConvertMemRefAssumeAlignment, ConvertMemRefSubview, ConvertMemrefStore>(
-      typeConverter, patterns.getContext());
+  patterns
+      .add<ConvertMemRefAlloc, ConvertMemRefLoad, ConvertMemRefAssumeAlignment,
+           ConvertMemRefSubview, ConvertMemrefStore>(typeConverter,
+                                                     patterns.getContext());
   memref::populateResolveExtractStridedMetadataPatterns(patterns);
 }
 

``````````

</details>


https://github.com/llvm/llvm-project/pull/72004


More information about the Mlir-commits mailing list