[Mlir-commits] [mlir] [mlir] Add narrow type emulation for `memref.reinterpret_cast` (PR #72003)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Nov 10 16:10:55 PST 2023


https://github.com/Max191 created https://github.com/llvm/llvm-project/pull/72003

None

>From 6665ad7640b308f05df83eb1e201d96a84779596 Mon Sep 17 00:00:00 2001
From: Max Dawkins <max.dawkins at gmail.com>
Date: Fri, 10 Nov 2023 18:59:35 -0500
Subject: [PATCH] [mlir] Add narrow type emulation for
 `memref.reinterpret_cast`

---
 .../MemRef/Transforms/EmulateNarrowType.cpp   | 66 ++++++++++++++++++-
 1 file changed, 63 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
index 9f58e9055acadbb..cafb1d4020d6dac 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
@@ -211,6 +211,65 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
   }
 };
 
+//===----------------------------------------------------------------------===//
+// ConvertMemRefReinterpretCast
+//===----------------------------------------------------------------------===//
+
+/// Rewrites a 0-D `memref.reinterpret_cast` onto the emulated element type.
+struct ConvertMemRefReinterpretCast final
+    : OpConversionPattern<memref::ReinterpretCastOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    MemRefType newTy =
+        dyn_cast<MemRefType>(getTypeConverter()->convertType(op.getType()));
+    if (!newTy) {
+      return rewriter.notifyMatchFailure(
+          op->getLoc(),
+          llvm::formatv("failed to convert memref type: {0}", op.getType()));
+    }
+    // The narrow element type must pack evenly into the emulated one.
+    auto convertedElementType = newTy.getElementType();
+    auto oldElementType = op.getType().getElementType();
+    int srcBits = oldElementType.getIntOrFloatBitWidth();
+    int dstBits = convertedElementType.getIntOrFloatBitWidth();
+    if (dstBits % srcBits != 0) {
+      return rewriter.notifyMatchFailure(
+          op, "only dstBits % srcBits == 0 supported");
+    }
+
+    // Only 0-D results are supported; they have no sizes/strides to rescale.
+    if (op.getType().getRank() != 0) {
+      return rewriter.notifyMatchFailure(
+          op->getLoc(), "reinterpret_cast with rank > 0 is not supported");
+    }
+
+    int64_t offset = op.getStaticOffset(0);
+    // Only a static offset can be rescaled at compile time.
+    if (offset == ShapedType::kDynamic) {
+      return rewriter.notifyMatchFailure(
+          op->getLoc(),
+          "reinterpret_cast with dynamic offset is not supported");
+    }
+
+    // Number of narrow elements packed into one emulated element.
+    int elementsPerWord = dstBits / srcBits;
+    if (offset % elementsPerWord != 0) {
+      return rewriter.notifyMatchFailure(
+          op->getLoc(), "reinterpret_cast with offset not multiple of "
+                        "elementsPerWord is not supported");
+    }
+    // Express the offset in units of the emulated element type.
+    offset /= elementsPerWord;
+    rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
+        op, newTy, adaptor.getSource(), offset,
+        /*sizes=*/SmallVector<int64_t>{}, op.getStaticStrides());
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // ConvertMemRefSubview
 //===----------------------------------------------------------------------===//
@@ -291,9 +350,10 @@ void memref::populateMemRefNarrowTypeEmulationPatterns(
     RewritePatternSet &patterns) {
 
   // Populate `memref.*` conversion patterns.
-  patterns.add<ConvertMemRefAlloc, ConvertMemRefLoad,
-               ConvertMemRefAssumeAlignment, ConvertMemRefSubview>(
-      typeConverter, patterns.getContext());
+  patterns
+      .add<ConvertMemRefAlloc, ConvertMemRefLoad, ConvertMemRefAssumeAlignment,
+           ConvertMemRefSubview, ConvertMemRefReinterpretCast>(
+          typeConverter, patterns.getContext());
   memref::populateResolveExtractStridedMetadataPatterns(patterns);
 }
 



More information about the Mlir-commits mailing list