[Mlir-commits] [mlir] [mlir] Add narrow type emulation for `memref.reinterpret_cast` (PR #73144)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Nov 22 09:14:03 PST 2023
https://github.com/Max191 created https://github.com/llvm/llvm-project/pull/73144
None
>From bf879d5066ac177f5f16bc44235c46c278bc6a47 Mon Sep 17 00:00:00 2001
From: Max Dawkins <max.dawkins at gmail.com>
Date: Fri, 10 Nov 2023 18:59:35 -0500
Subject: [PATCH] [mlir] Add narrow type emulation for
`memref.reinterpret_cast`
---
.../MemRef/Transforms/EmulateNarrowType.cpp | 136 ++++++++++++------
1 file changed, 95 insertions(+), 41 deletions(-)
diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
index 9f58e9055acadbb..dec5936fa7e83ce 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
@@ -17,11 +17,14 @@
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/Support/LogicalResult.h"
#include "mlir/Support/MathExtras.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MathExtras.h"
#include <cassert>
+#include <type_traits>
using namespace mlir;
@@ -29,6 +32,62 @@ using namespace mlir;
// Utility functions
//===----------------------------------------------------------------------===//
+/// Rewrites a memref::SubViewOp or memref::ReinterpretCastOp whose result is at
+/// most one-dimensional onto the converted element type of `newTy`.
+///
+/// The op must satisfy all of the following, otherwise a match failure is
+/// reported and nothing is rewritten:
+///   * the converted bitwidth is a multiple of the original bitwidth;
+///   * it carries exactly one offset and at most one size and one stride
+///     (rank-reducing subviews of multi-dimensional memrefs are rejected);
+///   * every stride is the static value 1;
+///   * the offset and size are static;
+///   * the offset, counted in original elements, is a multiple of the number
+///     of original elements packed into one converted element.
+///
+/// On success the op is replaced by an op of the same kind on the converted
+/// source, with offset and size expressed in converted elements. The size is
+/// rounded up so that a partially used trailing element is still covered.
+template <typename MemRefOpTy>
+static LogicalResult convertCastingOp(ConversionPatternRewriter &rewriter,
+                                      typename MemRefOpTy::Adaptor adaptor,
+                                      MemRefOpTy op, MemRefType newTy) {
+  static_assert(std::is_same<MemRefOpTy, memref::SubViewOp>() ||
+                    std::is_same<MemRefOpTy, memref::ReinterpretCastOp>(),
+                "Expected only memref::SubViewOp or memref::ReinterpretCastOp");
+
+  auto convertedElementType = newTy.getElementType();
+  auto oldElementType = op.getType().getElementType();
+  int srcBits = oldElementType.getIntOrFloatBitWidth();
+  int dstBits = convertedElementType.getIntOrFloatBitWidth();
+  if (dstBits % srcBits != 0) {
+    return rewriter.notifyMatchFailure(op,
+                                       "only dstBits % srcBits == 0 supported");
+  }
+
+  ArrayRef<int64_t> offsets = op.getStaticOffsets();
+  ArrayRef<int64_t> sizes = op.getStaticSizes();
+  ArrayRef<int64_t> strides = op.getStaticStrides();
+  // The replacement op is built with a single offset, at most one size, and the
+  // original strides. A rank-reducing subview of a multi-dimensional memref
+  // would otherwise use the wrong size and forward too many strides.
+  if (offsets.size() != 1 || sizes.size() > 1 || strides.size() > 1) {
+    return rewriter.notifyMatchFailure(
+        op->getLoc(), "only a single offset, size and stride is supported");
+  }
+
+  // Only support stride of 1.
+  if (llvm::any_of(strides, [](int64_t stride) { return stride != 1; })) {
+    return rewriter.notifyMatchFailure(op->getLoc(),
+                                       "stride != 1 is not supported");
+  }
+
+  // Read the raw static value instead of calling getStaticOffset(0): that
+  // accessor asserts the offset is static, which would turn the dynamic case
+  // below into a crash rather than a match failure.
+  int64_t offset = offsets.front();
+  // Only support static sizes and offsets.
+  if (llvm::any_of(sizes,
+                   [](int64_t size) { return size == ShapedType::kDynamic; }) ||
+      offset == ShapedType::kDynamic) {
+    return rewriter.notifyMatchFailure(
+        op->getLoc(), "dynamic size or offset is not supported");
+  }
+
+  // Number of original elements packed into one converted element.
+  int elementsPerByte = dstBits / srcBits;
+  if (offset % elementsPerByte != 0) {
+    return rewriter.notifyMatchFailure(
+        op->getLoc(), "offset not multiple of elementsPerByte is not "
+                      "supported");
+  }
+
+  SmallVector<int64_t> newSizes;
+  if (!sizes.empty())
+    newSizes.push_back(ceilDiv(sizes.front(), elementsPerByte));
+  int64_t newOffset = offset / elementsPerByte;
+
+  rewriter.replaceOpWithNewOp<MemRefOpTy>(op, newTy,
+                                          *adaptor.getODSOperands(0).begin(),
+                                          newOffset, newSizes, strides);
+  return success();
+}
+
/// When data is loaded/stored in `targetBits` granularity, but is used in
/// `sourceBits` granularity (`sourceBits` < `targetBits`), the `targetBits` is
/// treated as an array of elements of width `sourceBits`.
@@ -211,6 +270,37 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
}
};
+//===----------------------------------------------------------------------===//
+// ConvertMemRefReinterpretCast
+//===----------------------------------------------------------------------===//
+
+/// Converts a `memref.reinterpret_cast` on a narrow element type into one on
+/// the emulated (wider) element type. Converted output types are expected to be
+/// at most one dimensional, so only results of rank 0 or 1 are supported. The
+/// offset/size/stride rescaling itself is delegated to `convertCastingOp`.
+struct ConvertMemRefReinterpretCast final
+    : OpConversionPattern<memref::ReinterpretCastOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // convertType yields a null type when no conversion exists; use the
+    // null-tolerant cast so the failure is reported below instead of tripping
+    // dyn_cast's non-null precondition.
+    MemRefType newTy = dyn_cast_or_null<MemRefType>(
+        getTypeConverter()->convertType(op.getType()));
+    if (!newTy) {
+      return rewriter.notifyMatchFailure(
+          op->getLoc(),
+          llvm::formatv("failed to convert memref type: {0}", op.getType()));
+    }
+
+    // Only support for 0 or 1 dimensional cases.
+    if (op.getType().getRank() > 1) {
+      return rewriter.notifyMatchFailure(
+          op->getLoc(), "reinterpret_cast with rank > 1 is not supported");
+    }
+
+    return convertCastingOp(rewriter, adaptor, op, newTy);
+  }
+};
+
//===----------------------------------------------------------------------===//
// ConvertMemRefSubview
//===----------------------------------------------------------------------===//
@@ -233,50 +323,13 @@ struct ConvertMemRefSubview final : OpConversionPattern<memref::SubViewOp> {
llvm::formatv("failed to convert memref type: {0}", op.getType()));
}
- auto convertedElementType = newTy.getElementType();
- auto oldElementType = op.getType().getElementType();
- int srcBits = oldElementType.getIntOrFloatBitWidth();
- int dstBits = convertedElementType.getIntOrFloatBitWidth();
- if (dstBits % srcBits != 0) {
- return rewriter.notifyMatchFailure(
- op, "only dstBits % srcBits == 0 supported");
- }
-
// Only support offset for 1-D subview.
if (op.getType().getRank() != 1) {
return rewriter.notifyMatchFailure(
op->getLoc(), "subview with rank > 1 is not supported");
}
- // Only support stride of 1.
- if (op.getStaticStride(0) != 1) {
- return rewriter.notifyMatchFailure(
- op->getLoc(), "subview with stride != 1 is not supported");
- }
-
- int64_t size = op.getStaticSize(0);
- int64_t offset = op.getStaticOffset(0);
- // Only support static sizes and offsets.
- if (size == ShapedType::kDynamic || offset == ShapedType::kDynamic) {
- return rewriter.notifyMatchFailure(
- op->getLoc(), "subview with dynamic size or offset is not supported");
- }
-
- int elementsPerByte = dstBits / srcBits;
- if (offset % elementsPerByte != 0) {
- return rewriter.notifyMatchFailure(
- op->getLoc(),
- "subview with offset not multiple of elementsPerByte is not "
- "supported");
- }
-
- size = ceilDiv(size, elementsPerByte);
- offset = offset / elementsPerByte;
-
- rewriter.replaceOpWithNewOp<memref::SubViewOp>(
- op, newTy, *adaptor.getODSOperands(0).begin(), offset, size,
- op.getStaticStrides());
- return success();
+ return convertCastingOp(rewriter, adaptor, op, newTy);
}
};
@@ -291,9 +344,10 @@ void memref::populateMemRefNarrowTypeEmulationPatterns(
RewritePatternSet &patterns) {
// Populate `memref.*` conversion patterns.
- patterns.add<ConvertMemRefAlloc, ConvertMemRefLoad,
- ConvertMemRefAssumeAlignment, ConvertMemRefSubview>(
- typeConverter, patterns.getContext());
+ patterns
+ .add<ConvertMemRefAlloc, ConvertMemRefLoad, ConvertMemRefAssumeAlignment,
+ ConvertMemRefSubview, ConvertMemRefReinterpretCast>(
+ typeConverter, patterns.getContext());
memref::populateResolveExtractStridedMetadataPatterns(patterns);
}
More information about the Mlir-commits
mailing list