[Mlir-commits] [mlir] [mlir] Add subbyte emulation support for `memref.store`. (PR #72004)
llvmlistbot at llvm.org
Fri Nov 10 16:24:06 PST 2023
github-actions[bot] wrote:
:warning: The C/C++ code formatter, clang-format, found issues in your code. :warning:
<details>
<summary>
You can test this locally with the following command:
</summary>
``````````bash
git-clang-format --diff 64f62de96609dc3ea9a8a914a9e9445b7f4d625d 5e4abca5feb9c74b9105d9d907064992bb0d2fc5 -- mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
``````````
</details>
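(For reference: `git-clang-format --diff <base> <head> -- <paths>` only prints the formatting changes clang-format would make to the lines touched between the two commits; it leaves the working tree unmodified.)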
<details>
<summary>
View the diff from clang-format here.
</summary>
``````````diff
diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
index 44070e789d9e..7bed0162c3b8 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
@@ -51,14 +51,15 @@ static Value getOffsetForBitwidth(Location loc, OpFoldResult srcIdx,
}
static Value getRightOffsetForBitwidth(Location loc, OpFoldResult srcIdx,
- int sourceBits, int targetBits,
- OpBuilder &builder) {
+ int sourceBits, int targetBits,
+ OpBuilder &builder) {
assert(targetBits % sourceBits == 0);
AffineExpr s0;
bindSymbols(builder.getContext(), s0);
int scaleFactor = targetBits / sourceBits;
OpFoldResult offsetVal = affine::makeComposedFoldedAffineApply(
- builder, loc, (scaleFactor - 1 - s0 % scaleFactor) * sourceBits, {srcIdx});
+ builder, loc, (scaleFactor - 1 - s0 % scaleFactor) * sourceBits,
+ {srcIdx});
Value bitOffset = getValueOrCreateConstantIndexOp(builder, loc, offsetVal);
IntegerType dstType = builder.getIntegerType(targetBits);
return builder.create<arith::IndexCastOp>(loc, dstType, bitOffset);
@@ -248,20 +249,34 @@ struct ConvertMemrefStore final : OpConversionPattern<memref::StoreOp> {
Location loc = op.getLoc();
// Special case 0-rank memref stores.
auto dstIntegerType = rewriter.getIntegerType(dstBits);
- Value extendedInput = rewriter.create<arith::ExtUIOp>(loc, dstIntegerType, adaptor.getValue());
+ Value extendedInput = rewriter.create<arith::ExtUIOp>(loc, dstIntegerType,
+ adaptor.getValue());
if (convertedType.getRank() == 0) {
// Shift extended value to be left aligned
- auto shiftValAttr = rewriter.getIntegerAttr(dstIntegerType, dstBits - srcBits);
- Value shiftVal = rewriter.create<arith::ConstantOp>(loc, dstIntegerType, shiftValAttr).getResult();
- Value alignedVal = rewriter.create<arith::ShLIOp>(loc, extendedInput, shiftVal).getResult();
+ auto shiftValAttr =
+ rewriter.getIntegerAttr(dstIntegerType, dstBits - srcBits);
+ Value shiftVal =
+ rewriter.create<arith::ConstantOp>(loc, dstIntegerType, shiftValAttr)
+ .getResult();
+ Value alignedVal =
+ rewriter.create<arith::ShLIOp>(loc, extendedInput, shiftVal)
+ .getResult();
// Create mask to clear destination bits
- auto writeMaskValAttr = rewriter.getIntegerAttr(dstIntegerType, (1 << (dstBits - srcBits)) - 1);
- Value writeMask = rewriter.create<arith::ConstantOp>(loc, dstIntegerType, writeMaskValAttr).getResult();
+ auto writeMaskValAttr =
+ rewriter.getIntegerAttr(dstIntegerType, (1 << (dstBits - srcBits)) - 1);
+ Value writeMask =
+ rewriter
+ .create<arith::ConstantOp>(loc, dstIntegerType, writeMaskValAttr)
+ .getResult();
// Clear destination bits
- rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::andi, writeMask, adaptor.getMemref(), ValueRange{});
+ rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::andi,
+ writeMask, adaptor.getMemref(),
+ ValueRange{});
// Write source bits to destination
- rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::ori, alignedVal, adaptor.getMemref(), ValueRange{});
+ rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::ori,
+ alignedVal, adaptor.getMemref(),
+ ValueRange{});
} else {
SmallVector<OpFoldResult> indices =
getAsOpFoldResult(adaptor.getIndices());
@@ -285,25 +300,42 @@ struct ConvertMemrefStore final : OpConversionPattern<memref::StoreOp> {
OpFoldResult scaledLinearizedIndices =
affine::makeComposedFoldedAffineApply(
rewriter, loc, s0.floorDiv(scaler), {linearizedIndices});
- Value storeIndices = getValueOrCreateConstantIndexOp(rewriter, loc, scaledLinearizedIndices);
+ Value storeIndices = getValueOrCreateConstantIndexOp(
+ rewriter, loc, scaledLinearizedIndices);
// Create mask to clear destination bits
- Value bitwidthOffset = getRightOffsetForBitwidth(loc, linearizedIndices,
- srcBits, dstBits, rewriter);
- auto maskRightAlignedAttr = rewriter.getIntegerAttr(dstIntegerType, (1 << srcBits) - 1);
- Value maskRightAligned = rewriter.create<arith::ConstantOp>(loc, dstIntegerType, maskRightAlignedAttr).getResult();
- Value writeMaskInverse = rewriter.create<arith::ShLIOp>(loc, maskRightAligned, bitwidthOffset);
- // Perform bitwise NOT on the `writeMaskInverse` to get a mask that clears the destination bits
+ Value bitwidthOffset = getRightOffsetForBitwidth(
+ loc, linearizedIndices, srcBits, dstBits, rewriter);
+ auto maskRightAlignedAttr =
+ rewriter.getIntegerAttr(dstIntegerType, (1 << srcBits) - 1);
+ Value maskRightAligned =
+ rewriter
+ .create<arith::ConstantOp>(loc, dstIntegerType,
+ maskRightAlignedAttr)
+ .getResult();
+ Value writeMaskInverse =
+ rewriter.create<arith::ShLIOp>(loc, maskRightAligned, bitwidthOffset);
+ // Perform bitwise NOT on the `writeMaskInverse` to get a mask that clears
+ // the destination bits
auto flipValAttr = rewriter.getIntegerAttr(dstIntegerType, -1);
- Value flipVal = rewriter.create<arith::ConstantOp>(loc, dstIntegerType, flipValAttr).getResult();
- Value writeMask = rewriter.create<arith::XOrIOp>(loc, writeMaskInverse, flipVal);
+ Value flipVal =
+ rewriter.create<arith::ConstantOp>(loc, dstIntegerType, flipValAttr)
+ .getResult();
+ Value writeMask =
+ rewriter.create<arith::XOrIOp>(loc, writeMaskInverse, flipVal);
// Align the value to write with the destination bits
- Value alignedVal = rewriter.create<arith::ShLIOp>(loc, extendedInput, bitwidthOffset).getResult();
+ Value alignedVal =
+ rewriter.create<arith::ShLIOp>(loc, extendedInput, bitwidthOffset)
+ .getResult();
// Clear destination bits
- rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::andi, writeMask, adaptor.getMemref(), storeIndices);
+ rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::andi,
+ writeMask, adaptor.getMemref(),
+ storeIndices);
// Write source bits to destination
- rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::ori, alignedVal, adaptor.getMemref(), storeIndices);
+ rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::ori,
+ alignedVal, adaptor.getMemref(),
+ storeIndices);
}
rewriter.eraseOp(op);
@@ -391,9 +423,10 @@ void memref::populateMemRefNarrowTypeEmulationPatterns(
RewritePatternSet &patterns) {
// Populate `memref.*` conversion patterns.
- patterns.add<ConvertMemRefAlloc, ConvertMemRefLoad,
- ConvertMemRefAssumeAlignment, ConvertMemRefSubview, ConvertMemrefStore>(
- typeConverter, patterns.getContext());
+ patterns
+ .add<ConvertMemRefAlloc, ConvertMemRefLoad, ConvertMemRefAssumeAlignment,
+ ConvertMemRefSubview, ConvertMemrefStore>(typeConverter,
+ patterns.getContext());
memref::populateResolveExtractStridedMetadataPatterns(patterns);
}
``````````
</details>
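Formatting churn aside, the diff captures the core of the emulation: a subbyte `memref.store` is lowered to an atomic read-modify-write on the wider backing element. The standalone C++ sketch below mirrors that arithmetic for i4 values packed into i8 storage (the function and variable names are hypothetical, not from the PR): the linearized element index is scaled down to the backing index via `s0.floorDiv(scaler)`; `getRightOffsetForBitwidth` yields the shift `(scaleFactor - 1 - s0 % scaleFactor) * sourceBits`, so element 0 lands in the high nibble and element 1 in the low nibble; and the store itself is an atomic AND with the complemented mask followed by an atomic OR of the shifted value.

```cpp
#include <atomic>
#include <cstdint>

// Hypothetical helper (not from the PR): store a 4-bit value into a buffer
// of 8-bit backing elements the same way the emitted IR does.
void storeI4(std::atomic<uint8_t> *buf, int64_t elemIdx, uint8_t val) {
  constexpr int srcBits = 4, dstBits = 8;
  constexpr int scale = dstBits / srcBits;        // elements per backing byte
  int64_t byteIdx = elemIdx / scale;              // s0.floorDiv(scaler)
  // getRightOffsetForBitwidth: (scale - 1 - elemIdx % scale) * srcBits,
  // e.g. elemIdx 0 -> shift 4 (high nibble), elemIdx 1 -> shift 0 (low).
  int shift = static_cast<int>((scale - 1 - elemIdx % scale) * srcBits);
  uint8_t maskRightAligned = (1u << srcBits) - 1;               // 0b00001111
  uint8_t writeMaskInverse = static_cast<uint8_t>(maskRightAligned << shift);
  uint8_t writeMask = static_cast<uint8_t>(~writeMaskInverse);  // XOR with -1
  uint8_t alignedVal = static_cast<uint8_t>((val & maskRightAligned) << shift);
  buf[byteIdx].fetch_and(writeMask);  // clear destination bits (andi)
  buf[byteIdx].fetch_or(alignedVal);  // write source bits (ori)
}
```

One consequence, visible in the IR as a pair of `memref.atomic_rmw` ops: the clear and the set are two separate atomics, not one transaction, so concurrent stores to different subbyte elements of the same backing byte stay correct, but a racing load can observe the intermediate state in which the destination bits have been cleared but not yet rewritten.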
https://github.com/llvm/llvm-project/pull/72004