[Mlir-commits] [mlir] [mlir] Add narrow type emulation for `memref.reinterpret_cast` (PR #72003)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Nov 10 16:10:55 PST 2023
https://github.com/Max191 created https://github.com/llvm/llvm-project/pull/72003
None
>From 6665ad7640b308f05df83eb1e201d96a84779596 Mon Sep 17 00:00:00 2001
From: Max Dawkins <max.dawkins at gmail.com>
Date: Fri, 10 Nov 2023 18:59:35 -0500
Subject: [PATCH] [mlir] Add narrow type emulation for
`memref.reinterpret_cast`
---
.../MemRef/Transforms/EmulateNarrowType.cpp | 66 ++++++++++++++++++-
1 file changed, 63 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
index 9f58e9055acadbb..cafb1d4020d6dac 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
@@ -211,6 +211,65 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
}
};
+//===----------------------------------------------------------------------===//
+// ConvertMemRefReinterpretCast
+//===----------------------------------------------------------------------===//
+
+///
+struct ConvertMemRefReinterpretCast final
+ : OpConversionPattern<memref::ReinterpretCastOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ MemRefType newTy =
+ dyn_cast<MemRefType>(getTypeConverter()->convertType(op.getType()));
+ if (!newTy) {
+ return rewriter.notifyMatchFailure(
+ op->getLoc(),
+ llvm::formatv("failed to convert memref type: {0}", op.getType()));
+ }
+
+ auto convertedElementType = newTy.getElementType();
+ auto oldElementType = op.getType().getElementType();
+ int srcBits = oldElementType.getIntOrFloatBitWidth();
+ int dstBits = convertedElementType.getIntOrFloatBitWidth();
+ if (dstBits % srcBits != 0) {
+ return rewriter.notifyMatchFailure(
+ op, "only dstBits % srcBits == 0 supported");
+ }
+
+ // Only support offset for 0-D subview.
+ if (op.getType().getRank() != 0) {
+ return rewriter.notifyMatchFailure(
+ op->getLoc(), "subview with rank > 0 is not supported");
+ }
+
+ int64_t offset = op.getStaticOffset(0);
+ // Only support static sizes and offsets.
+ if (offset == ShapedType::kDynamic) {
+ return rewriter.notifyMatchFailure(
+ op->getLoc(), "subview with dynamic offset is not supported");
+ }
+
+ int elementsPerByte = dstBits / srcBits;
+ if (offset % elementsPerByte != 0) {
+ return rewriter.notifyMatchFailure(
+ op->getLoc(),
+ "subview with offset not multiple of elementsPerByte is not "
+ "supported");
+ }
+
+ offset = offset / elementsPerByte;
+
+ rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
+ op, newTy, *adaptor.getODSOperands(0).begin(), offset,
+ SmallVector<int64_t>{}, op.getStaticStrides());
+ return success();
+ }
+};
+
//===----------------------------------------------------------------------===//
// ConvertMemRefSubview
//===----------------------------------------------------------------------===//
@@ -291,9 +350,10 @@ void memref::populateMemRefNarrowTypeEmulationPatterns(
RewritePatternSet &patterns) {
// Populate `memref.*` conversion patterns.
- patterns.add<ConvertMemRefAlloc, ConvertMemRefLoad,
- ConvertMemRefAssumeAlignment, ConvertMemRefSubview>(
- typeConverter, patterns.getContext());
+ patterns
+ .add<ConvertMemRefAlloc, ConvertMemRefLoad, ConvertMemRefAssumeAlignment,
+ ConvertMemRefSubview, ConvertMemRefReinterpretCast>(
+ typeConverter, patterns.getContext());
memref::populateResolveExtractStridedMetadataPatterns(patterns);
}
More information about the Mlir-commits
mailing list