[Mlir-commits] [mlir] [mlir][memref] Fix emulate narrow types for strided memref offset (PR #68181)
Kunwar Grover
llvmlistbot at llvm.org
Wed Oct 4 22:18:31 PDT 2023
https://github.com/Groverkss updated https://github.com/llvm/llvm-project/pull/68181
>From 4697178ab883acb96c65d7dedff1a507f24613ec Mon Sep 17 00:00:00 2001
From: Groverkss <groverkss at gmail.com>
Date: Wed, 4 Oct 2023 08:56:54 +0530
Subject: [PATCH 1/4] [mlir][memref] Fix emulate narrow types for strided
memref offset
---
.../MemRef/Transforms/EmulateNarrowType.cpp | 17 +++++++++++++++--
1 file changed, 15 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
index 2a524ceb9db887b..2a9b27debaece3f 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
@@ -271,9 +271,22 @@ void memref::populateMemRefNarrowTypeEmulationConversions(
return std::nullopt;
StridedLayoutAttr layoutAttr;
+ // If the offset is 0, we do not need a strided layout as the stride is
+ // 1, so we only use the strided layout if the offset is not 0.
if (offset != 0) {
- layoutAttr = StridedLayoutAttr::get(ty.getContext(), offset,
- ArrayRef<int64_t>{1});
+ if (offset == ShapedType::kDynamic) {
+ layoutAttr = StridedLayoutAttr::get(ty.getContext(), offset,
+ ArrayRef<int64_t>{1});
+ } else {
+ // Check whether the offset scaled by the element width is a multiple
+ // of loadStoreWidth and, if so, divide by loadStoreWidth to get the offset.
+ if ((offset * width) % loadStoreWidth != 0)
+ return std::nullopt;
+ offset = (offset * width) / loadStoreWidth;
+
+ layoutAttr = StridedLayoutAttr::get(ty.getContext(), offset,
+ ArrayRef<int64_t>{1});
+ }
}
return MemRefType::get(getLinearizedShape(ty, width, loadStoreWidth),
>From f9b25612e2389980166523edb774a807c2228c3a Mon Sep 17 00:00:00 2001
From: Groverkss <groverkss at gmail.com>
Date: Wed, 4 Oct 2023 11:19:55 +0530
Subject: [PATCH 2/4] Add tests and memref.subview support in emulate narrow
types
---
.../MemRef/Transforms/EmulateNarrowType.cpp | 74 ++++++++++++++++++-
.../Dialect/MemRef/emulate-narrow-type.mlir | 19 +++++
2 files changed, 90 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
index 2a9b27debaece3f..453a18ff3c254e6 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
@@ -209,6 +209,74 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
return success();
}
};
+
+//===----------------------------------------------------------------------===//
+// ConvertMemRefAssumeAlignment
+//===----------------------------------------------------------------------===//
+
+struct ConvertMemRefSubview final : OpConversionPattern<memref::SubViewOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(memref::SubViewOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto convertedType =
+ cast<MemRefType>(getTypeConverter()->convertType(op.getSourceType()));
+ auto convertedElementType = convertedType.getElementType();
+ auto oldElementType = op.getSourceType().getElementType();
+ int srcBits = oldElementType.getIntOrFloatBitWidth();
+ int dstBits = convertedElementType.getIntOrFloatBitWidth();
+ if (dstBits % srcBits != 0) {
+ return rewriter.notifyMatchFailure(
+ op, "only dstBits % srcBits == 0 supported");
+ }
+
+ MemRefType newTy =
+ cast<MemRefType>(getTypeConverter()->convertType(op.getType()));
+ if (!newTy) {
+ return rewriter.notifyMatchFailure(
+ op->getLoc(),
+ llvm::formatv("failed to convert memref type: {0}", op.getType()));
+ }
+
+ // Only support offset for 1-D subview.
+ if (op.getType().getRank() != 1) {
+ return rewriter.notifyMatchFailure(
+ op->getLoc(), "subview with rank > 1 is not supported");
+ }
+
+ // Only support stride of 1.
+ if (op.getStaticStride(0) != 1) {
+ return rewriter.notifyMatchFailure(
+ op->getLoc(), "subview with stride != 1 is not supported");
+ }
+
+ auto size = op.getStaticSize(0);
+ auto offset = op.getStaticOffset(0);
+ // Only support static sizes and offsets.
+ if (size == ShapedType::kDynamic || offset == ShapedType::kDynamic) {
+ return rewriter.notifyMatchFailure(
+ op->getLoc(), "subview with dynamic size or offset is not supported");
+ }
+
+ int elementsPerByte = dstBits / srcBits;
+ if (size % elementsPerByte != 0 || offset % elementsPerByte != 0) {
+ return rewriter.notifyMatchFailure(
+ op->getLoc(),
+ "subview with size or offset not multiple of elementsPerByte is not "
+ "supported");
+ }
+
+ size = size / elementsPerByte;
+ offset = offset / elementsPerByte;
+
+ rewriter.replaceOpWithNewOp<memref::SubViewOp>(
+ op, newTy, *adaptor.getODSOperands(0).begin(), offset, size,
+ op.getStaticStrides());
+ return success();
+ }
+};
+
} // end anonymous namespace
//===----------------------------------------------------------------------===//
@@ -220,9 +288,9 @@ void memref::populateMemRefNarrowTypeEmulationPatterns(
RewritePatternSet &patterns) {
// Populate `memref.*` conversion patterns.
- patterns
- .add<ConvertMemRefAlloc, ConvertMemRefLoad, ConvertMemRefAssumeAlignment>(
- typeConverter, patterns.getContext());
+ patterns.add<ConvertMemRefAlloc, ConvertMemRefLoad,
+ ConvertMemRefAssumeAlignment, ConvertMemRefSubview>(
+ typeConverter, patterns.getContext());
memref::populateResolveExtractStridedMetadataPatterns(patterns);
}
diff --git a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
index c0050d8c510d53f..6ed97f05aa7cff2 100644
--- a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
+++ b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
@@ -155,3 +155,22 @@ func.func @rank_zero_memref() -> i4 {
// CHECK32: %[[LOAD:.+]] = memref.load %[[ALLOC]][] : memref<i32>
// CHECK32: %[[TRUNC:.+]] = arith.trunci %[[LOAD]] : i32 to i4
// CHECK32: return %[[TRUNC]]
+
+// -----
+
+func.func @memref_strided_i4(%idx : index) -> i4 {
+ %arr = memref.alloc() : memref<128xi4>
+ %subview = memref.subview %arr[32] [32] [1] : memref<128xi4> to memref<32xi4, strided<[1], offset:32>>
+ %1 = memref.load %subview[%idx] : memref<32xi4, strided<[1], offset:32>>
+ return %1 : i4
+}
+
+// CHECK-LABEL: func @memref_strided_i4
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<64xi8>
+// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ALLOC]][16] [16] [1] : memref<64xi8> to memref<16xi8, strided<[1], offset: 16>>
+// CHECK: %[[LOAD:.+]] = memref.load %[[SUBVIEW]]
+
+// CHECK32-LABEL: func @memref_strided_i4
+// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<16xi32>
+// CHECK32: %[[SUBVIEW:.+]] = memref.subview %[[ALLOC]][4] [4] [1] : memref<16xi32> to memref<4xi32, strided<[1], offset: 4>>
+// CHECK32: %[[LOAD:.+]] = memref.load %[[SUBVIEW]]
>From 3c3cd037829a6ed79375ae6b7d76a47dc919a58f Mon Sep 17 00:00:00 2001
From: Groverkss <groverkss at gmail.com>
Date: Wed, 4 Oct 2023 11:28:53 +0530
Subject: [PATCH 3/4] Fix doc
---
mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
index 453a18ff3c254e6..c98dda27f6c0b5a 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
@@ -211,7 +211,7 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
};
//===----------------------------------------------------------------------===//
-// ConvertMemRefAssumeAlignment
+// ConvertMemRefSubview
//===----------------------------------------------------------------------===//
struct ConvertMemRefSubview final : OpConversionPattern<memref::SubViewOp> {
>From 0b1caf93c97e08fcac6bbf6e95680d6adc7a2050 Mon Sep 17 00:00:00 2001
From: Groverkss <groverkss at gmail.com>
Date: Thu, 5 Oct 2023 10:47:56 +0530
Subject: [PATCH 4/4] Address Mahesh's comments
---
.../MemRef/Transforms/EmulateNarrowType.cpp | 37 ++++++++++---------
1 file changed, 20 insertions(+), 17 deletions(-)
diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
index c98dda27f6c0b5a..9f58e9055acadbb 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
@@ -17,6 +17,7 @@
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Support/MathExtras.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MathExtras.h"
@@ -214,31 +215,33 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
// ConvertMemRefSubview
//===----------------------------------------------------------------------===//
+/// Emulating narrow ints on subview has limited support: only static offsets
+/// and sizes with a stride of 1 are handled. Ideally, the subview should be
+/// folded away before running narrow type emulation, and this pattern would
+/// never run. This pattern is mostly used for testing purposes.
struct ConvertMemRefSubview final : OpConversionPattern<memref::SubViewOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(memref::SubViewOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto convertedType =
- cast<MemRefType>(getTypeConverter()->convertType(op.getSourceType()));
- auto convertedElementType = convertedType.getElementType();
- auto oldElementType = op.getSourceType().getElementType();
- int srcBits = oldElementType.getIntOrFloatBitWidth();
- int dstBits = convertedElementType.getIntOrFloatBitWidth();
- if (dstBits % srcBits != 0) {
- return rewriter.notifyMatchFailure(
- op, "only dstBits % srcBits == 0 supported");
- }
-
MemRefType newTy =
- cast<MemRefType>(getTypeConverter()->convertType(op.getType()));
+ dyn_cast<MemRefType>(getTypeConverter()->convertType(op.getType()));
if (!newTy) {
return rewriter.notifyMatchFailure(
op->getLoc(),
llvm::formatv("failed to convert memref type: {0}", op.getType()));
}
+ auto convertedElementType = newTy.getElementType();
+ auto oldElementType = op.getType().getElementType();
+ int srcBits = oldElementType.getIntOrFloatBitWidth();
+ int dstBits = convertedElementType.getIntOrFloatBitWidth();
+ if (dstBits % srcBits != 0) {
+ return rewriter.notifyMatchFailure(
+ op, "only dstBits % srcBits == 0 supported");
+ }
+
// Only support offset for 1-D subview.
if (op.getType().getRank() != 1) {
return rewriter.notifyMatchFailure(
@@ -251,8 +254,8 @@ struct ConvertMemRefSubview final : OpConversionPattern<memref::SubViewOp> {
op->getLoc(), "subview with stride != 1 is not supported");
}
- auto size = op.getStaticSize(0);
- auto offset = op.getStaticOffset(0);
+ int64_t size = op.getStaticSize(0);
+ int64_t offset = op.getStaticOffset(0);
// Only support static sizes and offsets.
if (size == ShapedType::kDynamic || offset == ShapedType::kDynamic) {
return rewriter.notifyMatchFailure(
@@ -260,14 +263,14 @@ struct ConvertMemRefSubview final : OpConversionPattern<memref::SubViewOp> {
}
int elementsPerByte = dstBits / srcBits;
- if (size % elementsPerByte != 0 || offset % elementsPerByte != 0) {
+ if (offset % elementsPerByte != 0) {
return rewriter.notifyMatchFailure(
op->getLoc(),
- "subview with size or offset not multiple of elementsPerByte is not "
+ "subview with offset not multiple of elementsPerByte is not "
"supported");
}
- size = size / elementsPerByte;
+ size = ceilDiv(size, elementsPerByte);
offset = offset / elementsPerByte;
rewriter.replaceOpWithNewOp<memref::SubViewOp>(
More information about the Mlir-commits
mailing list