[Mlir-commits] [mlir] [mlir][MemRef] Extend memref.subview sub-byte type emulation support. (PR #94045)
Han-Chung Wang
llvmlistbot at llvm.org
Mon Jun 3 14:53:58 PDT 2024
https://github.com/hanhanW updated https://github.com/llvm/llvm-project/pull/94045
>From cb71135fc2747925d998d2b9a0ffe6ee8f7c7a6f Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Fri, 31 May 2024 13:27:56 -0700
Subject: [PATCH 1/3] [mlir][MemRef] Extend memref.subview sub-byte type
emulation support.
In some cases (see https://github.com/iree-org/iree/issues/16285), `memref.subview` ops can't be folded into transfer ops, so sub-byte type emulation fails. This has been blocking several things, including the enablement of vector flattening transformations (https://github.com/iree-org/iree/pull/16456). This PR extends the existing sub-byte type emulation support for `memref.subview` to handle multi-dimensional subviews with dynamic offsets, which addresses the `memref.subview` cases that can't be folded.
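For illustration, here is a sketch of the rewrite this enables, using the shapes from the new test added below (the exact affine map for the offset is produced by the linearization utility). An op such as

  %subview = memref.subview %arr[%idx, 0, 0, 0] [16, 64, 8, 16] [1, 1, 1, 1]
      : memref<512x64x8x16xi4> to memref<16x64x8x16xi4, strided<[8192, 128, 16, 1], offset: ?>>

is rewritten, under i8 emulation, into a linearized 1-D subview over the byte buffer:

  %alloc = memref.alloc() : memref<2097152xi8>
  %offset = affine.apply affine_map<()[s0] -> (s0 * 4096)>()[%idx]
  %subview = memref.subview %alloc[%offset] [65536] [1]
      : memref<2097152xi8> to memref<65536xi8, strided<[1], offset: ?>>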
---
.../MemRef/Transforms/EmulateNarrowType.cpp | 160 ++++++++++--------
mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp | 4 +-
.../Dialect/MemRef/emulate-narrow-type.mlir | 23 +++
3 files changed, 116 insertions(+), 71 deletions(-)
diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
index 77c108aab4807..3510cc89cc358 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
@@ -32,62 +32,6 @@ using namespace mlir;
// Utility functions
//===----------------------------------------------------------------------===//
-/// Converts a memref::SubViewOp or memref::ReinterpretCastOp to the converted
-/// type. The result MemRefType of the old op must have a rank and stride of 1,
-/// with static offset and size. The number of bits in the offset must evenly
-/// divide the bitwidth of the new converted type.
-template <typename MemRefOpTy>
-static LogicalResult convertCastingOp(ConversionPatternRewriter &rewriter,
- typename MemRefOpTy::Adaptor adaptor,
- MemRefOpTy op, MemRefType newTy) {
- static_assert(std::is_same<MemRefOpTy, memref::SubViewOp>() ||
- std::is_same<MemRefOpTy, memref::ReinterpretCastOp>(),
- "Expected only memref::SubViewOp or memref::ReinterpretCastOp");
-
- auto convertedElementType = newTy.getElementType();
- auto oldElementType = op.getType().getElementType();
- int srcBits = oldElementType.getIntOrFloatBitWidth();
- int dstBits = convertedElementType.getIntOrFloatBitWidth();
- if (dstBits % srcBits != 0) {
- return rewriter.notifyMatchFailure(op,
- "only dstBits % srcBits == 0 supported");
- }
-
- // Only support stride of 1.
- if (llvm::any_of(op.getStaticStrides(),
- [](int64_t stride) { return stride != 1; })) {
- return rewriter.notifyMatchFailure(op->getLoc(),
- "stride != 1 is not supported");
- }
-
- auto sizes = op.getStaticSizes();
- int64_t offset = op.getStaticOffset(0);
- // Only support static sizes and offsets.
- if (llvm::any_of(sizes,
- [](int64_t size) { return size == ShapedType::kDynamic; }) ||
- offset == ShapedType::kDynamic) {
- return rewriter.notifyMatchFailure(
- op->getLoc(), "dynamic size or offset is not supported");
- }
-
- int elementsPerByte = dstBits / srcBits;
- if (offset % elementsPerByte != 0) {
- return rewriter.notifyMatchFailure(
- op->getLoc(), "offset not multiple of elementsPerByte is not "
- "supported");
- }
-
- SmallVector<int64_t> size;
- if (sizes.size())
- size.push_back(ceilDiv(sizes[0], elementsPerByte));
- offset = offset / elementsPerByte;
-
- rewriter.replaceOpWithNewOp<MemRefOpTy>(op, newTy,
- *adaptor.getODSOperands(0).begin(),
- offset, size, op.getStaticStrides());
- return success();
-}
-
/// When data is loaded/stored in `targetBits` granularity, but is used in
/// `sourceBits` granularity (`sourceBits` < `targetBits`), the `targetBits` is
/// treated as an array of elements of width `sourceBits`.
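As a concrete illustration of this arithmetic (not part of the patch): with sourceBits = 4 and targetBits = 8, each loaded i8 holds two i4 elements, so element index i lives in byte i / 2, at bit offset (i % 2) * 4 within that byte.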
@@ -335,7 +279,48 @@ struct ConvertMemRefReinterpretCast final
op->getLoc(), "subview with rank > 1 is not supported");
}
- return convertCastingOp(rewriter, adaptor, op, newTy);
+ Type convertedElementType = newTy.getElementType();
+ Type oldElementType = op.getType().getElementType();
+ int srcBits = oldElementType.getIntOrFloatBitWidth();
+ int dstBits = convertedElementType.getIntOrFloatBitWidth();
+ if (dstBits % srcBits != 0) {
+ return rewriter.notifyMatchFailure(
+ op, "only dstBits % srcBits == 0 supported");
+ }
+
+ // Only support stride of 1.
+ if (llvm::any_of(op.getStaticStrides(),
+ [](int64_t stride) { return stride != 1; })) {
+ return rewriter.notifyMatchFailure(op->getLoc(),
+ "stride != 1 is not supported");
+ }
+
+ auto sizes = op.getStaticSizes();
+ int64_t offset = op.getStaticOffset(0);
+ // Only support static sizes and offsets.
+ if (llvm::any_of(
+ sizes, [](int64_t size) { return size == ShapedType::kDynamic; }) ||
+ offset == ShapedType::kDynamic) {
+ return rewriter.notifyMatchFailure(
+ op->getLoc(), "dynamic size or offset is not supported");
+ }
+
+ int elementsPerByte = dstBits / srcBits;
+ if (offset % elementsPerByte != 0) {
+ return rewriter.notifyMatchFailure(
+ op->getLoc(), "offset not multiple of elementsPerByte is not "
+ "supported");
+ }
+
+ SmallVector<int64_t> size;
+ if (sizes.size())
+ size.push_back(ceilDiv(sizes[0], elementsPerByte));
+ offset = offset / elementsPerByte;
+
+ rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
+ op, newTy, *adaptor.getODSOperands(0).begin(), offset, size,
+ op.getStaticStrides());
+ return success();
}
};
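As a hand-written sketch of this conversion (not taken from the patch) under 8-bit emulation, where elementsPerByte = 2:

  %cast = memref.reinterpret_cast %src to offset: [4], sizes: [8], strides: [1]
      : memref<16xi4> to memref<8xi4, strided<[1], offset: 4>>

becomes, with %src already converted to memref<8xi8>:

  %cast = memref.reinterpret_cast %src to offset: [2], sizes: [4], strides: [1]
      : memref<8xi8> to memref<4xi8, strided<[1], offset: 2>>

i.e. the offset is divided by elementsPerByte and the size is rounded up with ceilDiv.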
@@ -402,29 +387,68 @@ struct ConvertMemrefStore final : OpConversionPattern<memref::StoreOp> {
/// Emulating narrow ints on subview has limited support: only static offsets
/// and sizes with a stride of 1 are handled. Ideally, the subview should be
-/// folded away before running narrow type emulation, and this pattern would
-/// never run. This pattern is mostly used for testing pruposes.
+/// folded away before running narrow type emulation, and this pattern should
+/// only run for cases that can't be folded.
struct ConvertMemRefSubview final : OpConversionPattern<memref::SubViewOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
- matchAndRewrite(memref::SubViewOp op, OpAdaptor adaptor,
+ matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- MemRefType newTy =
- dyn_cast<MemRefType>(getTypeConverter()->convertType(op.getType()));
+ MemRefType newTy = dyn_cast<MemRefType>(
+ getTypeConverter()->convertType(subViewOp.getType()));
if (!newTy) {
return rewriter.notifyMatchFailure(
- op->getLoc(),
- llvm::formatv("failed to convert memref type: {0}", op.getType()));
+ subViewOp->getLoc(),
+ llvm::formatv("failed to convert memref type: {0}",
+ subViewOp.getType()));
}
- // Only support offset for 1-D subview.
- if (op.getType().getRank() != 1) {
+ Location loc = subViewOp.getLoc();
+ Type convertedElementType = newTy.getElementType();
+ Type oldElementType = subViewOp.getType().getElementType();
+ int srcBits = oldElementType.getIntOrFloatBitWidth();
+ int dstBits = convertedElementType.getIntOrFloatBitWidth();
+ if (dstBits % srcBits != 0)
return rewriter.notifyMatchFailure(
- op->getLoc(), "subview with rank > 1 is not supported");
+ subViewOp, "only dstBits % srcBits == 0 supported");
+
+ // Only support stride of 1.
+ if (llvm::any_of(subViewOp.getStaticStrides(),
+ [](int64_t stride) { return stride != 1; })) {
+ return rewriter.notifyMatchFailure(subViewOp->getLoc(),
+ "stride != 1 is not supported");
+ }
+
+ auto sizes = subViewOp.getStaticSizes();
+ int64_t lastOffset = subViewOp.getStaticOffsets().back();
+ // Only support static sizes and offsets.
+ if (llvm::any_of(
+ sizes, [](int64_t size) { return size == ShapedType::kDynamic; }) ||
+ lastOffset == ShapedType::kDynamic) {
+ return rewriter.notifyMatchFailure(
+ subViewOp->getLoc(), "dynamic size or offset is not supported");
}
- return convertCastingOp(rewriter, adaptor, op, newTy);
+ // Transform the offsets, sizes and strides according to the emulation.
+ auto stridedMetadata = rewriter.create<memref::ExtractStridedMetadataOp>(
+ loc, subViewOp.getViewSource());
+
+ OpFoldResult linearizedIndices;
+ auto strides = stridedMetadata.getConstifiedMixedStrides();
+ memref::LinearizedMemRefInfo linearizedInfo;
+ std::tie(linearizedInfo, linearizedIndices) =
+ memref::getLinearizedMemRefOffsetAndSize(
+ rewriter, loc, srcBits, dstBits,
+ stridedMetadata.getConstifiedMixedOffset(),
+ subViewOp.getMixedSizes(), strides,
+ getMixedValues(adaptor.getStaticOffsets(), adaptor.getOffsets(),
+ rewriter));
+
+ rewriter.replaceOpWithNewOp<memref::SubViewOp>(
+ subViewOp, newTy, adaptor.getSource(), linearizedIndices,
+ linearizedInfo.linearizedSize, strides.back());
+ return success();
}
};
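To spell out what getLinearizedMemRefOffsetAndSize computes for the new test below: the source is memref<512x64x8x16xi4> with strides [8192, 128, 16, 1] and offsets [%idx, 0, 0, 0], so with srcBits = 4, dstBits = 8, and scaler = 2 the linearized offset is (%idx * 8192) floordiv 2 = %idx * 4096 and the linearized size is (16 * 64 * 8 * 16) floordiv 2 = 65536, matching the affine.apply and the 1-D subview in the CHECK lines.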
diff --git a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
index 05d5ca2ce12f4..68edd45448ee5 100644
--- a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
+++ b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
@@ -68,7 +68,6 @@ std::pair<LinearizedMemRefInfo, OpFoldResult> getLinearizedMemRefOffsetAndSize(
AffineExpr mulMap = builder.getAffineConstantExpr(1);
SmallVector<OpFoldResult> offsetValues(2 * sourceRank);
- SmallVector<OpFoldResult> sizeValues(sourceRank);
for (unsigned i = 0; i < sourceRank; ++i) {
unsigned offsetIdx = 2 * i;
@@ -79,8 +78,7 @@ std::pair<LinearizedMemRefInfo, OpFoldResult> getLinearizedMemRefOffsetAndSize(
mulMap = mulMap * symbols[i];
}
- // Adjust linearizedIndices, size and offset by the scale factor (dstBits /
- // srcBits).
+ // Adjust linearizedIndices and size by the scale factor (dstBits / srcBits).
int64_t scaler = dstBits / srcBits;
addMulMap = addMulMap.floorDiv(scaler);
mulMap = mulMap.floorDiv(scaler);
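Roughly, with scaler = dstBits / srcBits, the utility now returns

  linearizedOffset = (offset_0 * stride_0 + ... + offset_{n-1} * stride_{n-1}) floordiv scaler
  linearizedSize   = (size_0 * ... * size_{n-1}) floordiv scaler

and, per the updated comment, the base offset of the source buffer is no longer adjusted by this factor here.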
diff --git a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
index 435dcc944778d..1f7797e16d317 100644
--- a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
+++ b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
@@ -177,6 +177,29 @@ func.func @memref_strided_i4(%idx : index) -> i4 {
// -----
+func.func @memref_subview_dynamic_offset_i4(%idx : index) -> i4 {
+ %c0 = arith.constant 0 : index
+ %arr = memref.alloc() : memref<512x64x8x16xi4>
+ %subview = memref.subview %arr[%idx, 0, 0, 0] [16, 64, 8, 16] [1, 1, 1, 1] : memref<512x64x8x16xi4>
+ to memref<16x64x8x16xi4, strided<[8192, 128, 16, 1], offset: ?>>
+ %ld = memref.load %subview[%c0, %c0, %c0, %c0] : memref<16x64x8x16xi4, strided<[8192, 128, 16, 1], offset: ?>>
+ return %ld : i4
+}
+
+// CHECK-LABEL: func.func @memref_subview_dynamic_offset_i4(
+// CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<2097152xi8>
+// CHECK: %[[IDX:.*]] = affine.apply
+// CHECK: %[[SUBVIEW:.*]] = memref.subview %[[ALLOC]][%[[IDX]]] [65536] [1] : memref<2097152xi8> to memref<65536xi8, strided<[1], offset: ?>>
+// CHECK: memref.load %[[SUBVIEW]]
+
+// CHECK32-LABEL: func.func @memref_subview_dynamic_offset_i4(
+// CHECK32: %[[ALLOC:.*]] = memref.alloc() : memref<524288xi32>
+// CHECK32: %[[IDX:.*]] = affine.apply
+// CHECK32: %[[SUBVIEW:.*]] = memref.subview %[[ALLOC]][%[[IDX]]] [16384] [1] : memref<524288xi32> to memref<16384xi32, strided<[1], offset: ?>>
+// CHECK32: memref.load %[[SUBVIEW]]
+
+// -----
+
func.func @reinterpret_cast_memref_load_0D() -> i4 {
%0 = memref.alloc() : memref<5xi4>
%reinterpret_cast_0 = memref.reinterpret_cast %0 to offset: [0], sizes: [], strides: [] : memref<5xi4> to memref<i4>
>From 708925e271078725c2aa07ad6c2e7f6ef22babbd Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Fri, 31 May 2024 13:35:03 -0700
Subject: [PATCH 2/3] Address comments
---
.../MemRef/Transforms/EmulateNarrowType.cpp | 5 +++++
.../test/Dialect/MemRef/emulate-narrow-type.mlir | 16 ++++++++++++++--
2 files changed, 19 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
index 3510cc89cc358..bfe97672aaf8b 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
@@ -420,6 +420,11 @@ struct ConvertMemRefSubview final : OpConversionPattern<memref::SubViewOp> {
"stride != 1 is not supported");
}
+ if (!memref::isStaticShapeAndContiguousRowMajor(subViewOp.getType())) {
+ return rewriter.notifyMatchFailure(
+ subViewOp, "the result memref type is not contiguous");
+ }
+
auto sizes = subViewOp.getStaticSizes();
int64_t lastOffset = subViewOp.getStaticOffsets().back();
// Only support static sizes and offsets.
diff --git a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
index 1f7797e16d317..a67237b5e4dd1 100644
--- a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
+++ b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
@@ -1,5 +1,5 @@
-// RUN: mlir-opt --test-emulate-narrow-int="memref-load-bitwidth=8" --cse --split-input-file %s | FileCheck %s
-// RUN: mlir-opt --test-emulate-narrow-int="memref-load-bitwidth=32" --cse --split-input-file %s | FileCheck %s --check-prefix=CHECK32
+// RUN: mlir-opt --test-emulate-narrow-int="memref-load-bitwidth=8" --cse --verify-diagnostics --split-input-file %s | FileCheck %s
+// RUN: mlir-opt --test-emulate-narrow-int="memref-load-bitwidth=32" --cse --verify-diagnostics --split-input-file %s | FileCheck %s --check-prefix=CHECK32
// Expect no conversions.
func.func @memref_i8() -> i8 {
@@ -200,6 +200,18 @@ func.func @memref_subview_dynamic_offset_i4(%idx : index) -> i4 {
// -----
+
+func.func @negative_memref_subview_non_contiguous(%idx : index) -> i4 {
+ %c0 = arith.constant 0 : index
+ %arr = memref.alloc() : memref<40x40xi4>
+ // expected-error @+1 {{failed to legalize operation 'memref.subview' that was explicitly marked illegal}}
+ %subview = memref.subview %arr[%idx, 0] [4, 8] [1, 1] : memref<40x40xi4> to memref<4x8xi4, strided<[40, 1], offset:?>>
+ %ld = memref.load %subview[%c0, %c0] : memref<4x8xi4, strided<[40, 1], offset:?>>
+ return %ld : i4
+}
+
+// -----
+
func.func @reinterpret_cast_memref_load_0D() -> i4 {
%0 = memref.alloc() : memref<5xi4>
%reinterpret_cast_0 = memref.reinterpret_cast %0 to offset: [0], sizes: [], strides: [] : memref<5xi4> to memref<i4>
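For context on the new negative test above: the subview takes sizes [4, 8] out of a row-major memref<40x40xi4>, so consecutive subview rows are 40 elements apart in memory while only 8 of them are read. The selected elements are therefore not contiguous, and collapsing them into a single linearized 1-D subview over the byte buffer would cover the wrong elements, which is why the isStaticShapeAndContiguousRowMajor check rejects this case.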
>From c0097ec2782dba6d14cddf4cf799ab51821f439b Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Mon, 3 Jun 2024 14:53:27 -0700
Subject: [PATCH 3/3] address comments
---
.../MemRef/Transforms/EmulateNarrowType.cpp | 93 ++++++++++---------
1 file changed, 51 insertions(+), 42 deletions(-)
diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
index bfe97672aaf8b..2392e10522dd0 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
@@ -32,6 +32,56 @@ using namespace mlir;
// Utility functions
//===----------------------------------------------------------------------===//
+/// Converts a memref::ReinterpretCastOp to the converted type. The result
+/// MemRefType of the old op must have rank 1, a stride of 1, and a static
+/// offset and size. The offset must be a multiple of the ratio of the new
+/// and old element bitwidths (dstBits / srcBits).
+static LogicalResult
+convertCastingOp(ConversionPatternRewriter &rewriter,
+ memref::ReinterpretCastOp::Adaptor adaptor,
+ memref::ReinterpretCastOp op, MemRefType newTy) {
+ auto convertedElementType = newTy.getElementType();
+ auto oldElementType = op.getType().getElementType();
+ int srcBits = oldElementType.getIntOrFloatBitWidth();
+ int dstBits = convertedElementType.getIntOrFloatBitWidth();
+ if (dstBits % srcBits != 0) {
+ return rewriter.notifyMatchFailure(op,
+ "only dstBits % srcBits == 0 supported");
+ }
+
+ // Only support stride of 1.
+ if (llvm::any_of(op.getStaticStrides(),
+ [](int64_t stride) { return stride != 1; })) {
+ return rewriter.notifyMatchFailure(op->getLoc(),
+ "stride != 1 is not supported");
+ }
+
+ auto sizes = op.getStaticSizes();
+ int64_t offset = op.getStaticOffset(0);
+ // Only support static sizes and offsets.
+ if (llvm::any_of(sizes,
+ [](int64_t size) { return size == ShapedType::kDynamic; }) ||
+ offset == ShapedType::kDynamic) {
+ return rewriter.notifyMatchFailure(
+ op, "dynamic size or offset is not supported");
+ }
+
+ int elementsPerByte = dstBits / srcBits;
+ if (offset % elementsPerByte != 0) {
+ return rewriter.notifyMatchFailure(
+ op, "offset not multiple of elementsPerByte is not supported");
+ }
+
+ SmallVector<int64_t> size;
+ if (sizes.size())
+ size.push_back(ceilDiv(sizes[0], elementsPerByte));
+ offset = offset / elementsPerByte;
+
+ rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
+ op, newTy, adaptor.getSource(), offset, size, op.getStaticStrides());
+ return success();
+}
+
/// When data is loaded/stored in `targetBits` granularity, but is used in
/// `sourceBits` granularity (`sourceBits` < `targetBits`), the `targetBits` is
/// treated as an array of elements of width `sourceBits`.
@@ -279,48 +329,7 @@ struct ConvertMemRefReinterpretCast final
op->getLoc(), "subview with rank > 1 is not supported");
}
- Type convertedElementType = newTy.getElementType();
- Type oldElementType = op.getType().getElementType();
- int srcBits = oldElementType.getIntOrFloatBitWidth();
- int dstBits = convertedElementType.getIntOrFloatBitWidth();
- if (dstBits % srcBits != 0) {
- return rewriter.notifyMatchFailure(
- op, "only dstBits % srcBits == 0 supported");
- }
-
- // Only support stride of 1.
- if (llvm::any_of(op.getStaticStrides(),
- [](int64_t stride) { return stride != 1; })) {
- return rewriter.notifyMatchFailure(op->getLoc(),
- "stride != 1 is not supported");
- }
-
- auto sizes = op.getStaticSizes();
- int64_t offset = op.getStaticOffset(0);
- // Only support static sizes and offsets.
- if (llvm::any_of(
- sizes, [](int64_t size) { return size == ShapedType::kDynamic; }) ||
- offset == ShapedType::kDynamic) {
- return rewriter.notifyMatchFailure(
- op->getLoc(), "dynamic size or offset is not supported");
- }
-
- int elementsPerByte = dstBits / srcBits;
- if (offset % elementsPerByte != 0) {
- return rewriter.notifyMatchFailure(
- op->getLoc(), "offset not multiple of elementsPerByte is not "
- "supported");
- }
-
- SmallVector<int64_t> size;
- if (sizes.size())
- size.push_back(ceilDiv(sizes[0], elementsPerByte));
- offset = offset / elementsPerByte;
-
- rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
- op, newTy, *adaptor.getODSOperands(0).begin(), offset, size,
- op.getStaticStrides());
- return success();
+ return convertCastingOp(rewriter, adaptor, op, newTy);
}
};