[Mlir-commits] [mlir] b29332a - [mlir] Add narrow type emulation for `memref.reinterpret_cast` (#73144)

llvmlistbot at llvm.org
Mon Nov 27 10:41:18 PST 2023


Author: Max191
Date: 2023-11-27T10:41:14-08:00
New Revision: b29332a318d922aa3a626feeb3e39a312c32ba12

URL: https://github.com/llvm/llvm-project/commit/b29332a318d922aa3a626feeb3e39a312c32ba12
DIFF: https://github.com/llvm/llvm-project/commit/b29332a318d922aa3a626feeb3e39a312c32ba12.diff

LOG: [mlir] Add narrow type emulation for `memref.reinterpret_cast` (#73144)
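
In short: a `memref.reinterpret_cast` whose result element type is narrower
than the emulation bitwidth is rewritten against the wider storage type, with
its offset and sizes rescaled. Sketching the 1-D case that the new test below
checks (i4 data emulated on i8 storage; SSA names are illustrative):

    %rc = memref.reinterpret_cast %m to offset: [8], sizes: [25], strides: [1]
        : memref<5x5xi4> to memref<25xi4, strided<[1], offset: 8>>

becomes, with elementsPerByte = 8 / 4 = 2 (so offset 8 / 2 = 4 and size
ceilDiv(25, 2) = 13):

    %rc = memref.reinterpret_cast %m to offset: [4], sizes: [13], strides: [1]
        : memref<13xi8> to memref<13xi8, strided<[1], offset: 4>>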

Added: 
    

Modified: 
    mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
    mlir/test/Dialect/MemRef/emulate-narrow-type.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
index 9f58e9055acadbb..dec5936fa7e83ce 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
@@ -17,11 +17,14 @@
 #include "mlir/Dialect/MemRef/Transforms/Transforms.h"
 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/Support/LogicalResult.h"
 #include "mlir/Support/MathExtras.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/Support/FormatVariadic.h"
 #include "llvm/Support/MathExtras.h"
 #include <cassert>
+#include <type_traits>
 
 using namespace mlir;
 
@@ -29,6 +32,62 @@ using namespace mlir;
 // Utility functions
 //===----------------------------------------------------------------------===//
 
+/// Converts a memref::SubViewOp or memref::ReinterpretCastOp to the converted
+/// type. The result MemRefType of the old op must have rank at most 1, unit
+/// strides, and a static offset and sizes. The offset, in bits of the source
+/// element type, must be a multiple of the new element type's bitwidth.
+template <typename MemRefOpTy>
+static LogicalResult convertCastingOp(ConversionPatternRewriter &rewriter,
+                                      typename MemRefOpTy::Adaptor adaptor,
+                                      MemRefOpTy op, MemRefType newTy) {
+  static_assert(std::is_same<MemRefOpTy, memref::SubViewOp>() ||
+                    std::is_same<MemRefOpTy, memref::ReinterpretCastOp>(),
+                "Expected only memref::SubViewOp or memref::ReinterpretCastOp");
+
+  auto convertedElementType = newTy.getElementType();
+  auto oldElementType = op.getType().getElementType();
+  int srcBits = oldElementType.getIntOrFloatBitWidth();
+  int dstBits = convertedElementType.getIntOrFloatBitWidth();
+  if (dstBits % srcBits != 0) {
+    return rewriter.notifyMatchFailure(op,
+                                       "only dstBits % srcBits == 0 supported");
+  }
+
+  // Only support stride of 1.
+  if (llvm::any_of(op.getStaticStrides(),
+                   [](int64_t stride) { return stride != 1; })) {
+    return rewriter.notifyMatchFailure(op->getLoc(),
+                                       "stride != 1 is not supported");
+  }
+
+  auto sizes = op.getStaticSizes();
+  int64_t offset = op.getStaticOffset(0);
+  // Only support static sizes and offsets.
+  if (llvm::any_of(sizes,
+                   [](int64_t size) { return size == ShapedType::kDynamic; }) ||
+      offset == ShapedType::kDynamic) {
+    return rewriter.notifyMatchFailure(
+        op->getLoc(), "dynamic size or offset is not supported");
+  }
+
+  int elementsPerByte = dstBits / srcBits;
+  if (offset % elementsPerByte != 0) {
+    return rewriter.notifyMatchFailure(
+        op->getLoc(), "offset not multiple of elementsPerByte is not "
+                      "supported");
+  }
+
+  SmallVector<int64_t> size;
+  if (!sizes.empty())
+    size.push_back(ceilDiv(sizes[0], elementsPerByte));
+  offset = offset / elementsPerByte;
+
+  rewriter.replaceOpWithNewOp<MemRefOpTy>(op, newTy,
+                                          *adaptor.getODSOperands(0).begin(),
+                                          offset, size, op.getStaticStrides());
+  return success();
+}
+
 /// When data is loaded/stored in `targetBits` granularity, but is used in
 /// `sourceBits` granularity (`sourceBits` < `targetBits`), the `targetBits` is
 /// treated as an array of elements of width `sourceBits`.
@@ -211,6 +270,37 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
   }
 };
 
+//===----------------------------------------------------------------------===//
+// ConvertMemRefReinterpretCast
+//===----------------------------------------------------------------------===//
+
+/// Emulates a narrow-type memref::ReinterpretCastOp. Only results of rank 0
+/// or 1 are supported.
+struct ConvertMemRefReinterpretCast final
+    : OpConversionPattern<memref::ReinterpretCastOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    MemRefType newTy =
+        dyn_cast<MemRefType>(getTypeConverter()->convertType(op.getType()));
+    if (!newTy) {
+      return rewriter.notifyMatchFailure(
+          op->getLoc(),
+          llvm::formatv("failed to convert memref type: {0}", op.getType()));
+    }
+
+    // Only rank 0 or 1 results are supported.
+    if (op.getType().getRank() > 1) {
+      return rewriter.notifyMatchFailure(
+          op->getLoc(), "reinterpret_cast with rank > 1 is not supported");
+    }
+
+    return convertCastingOp(rewriter, adaptor, op, newTy);
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // ConvertMemRefSubview
 //===----------------------------------------------------------------------===//
@@ -233,50 +323,13 @@ struct ConvertMemRefSubview final : OpConversionPattern<memref::SubViewOp> {
           llvm::formatv("failed to convert memref type: {0}", op.getType()));
     }
 
-    auto convertedElementType = newTy.getElementType();
-    auto oldElementType = op.getType().getElementType();
-    int srcBits = oldElementType.getIntOrFloatBitWidth();
-    int dstBits = convertedElementType.getIntOrFloatBitWidth();
-    if (dstBits % srcBits != 0) {
-      return rewriter.notifyMatchFailure(
-          op, "only dstBits % srcBits == 0 supported");
-    }
-
     // Only support offset for 1-D subview.
     if (op.getType().getRank() != 1) {
       return rewriter.notifyMatchFailure(
           op->getLoc(), "subview with rank > 1 is not supported");
     }
 
-    // Only support stride of 1.
-    if (op.getStaticStride(0) != 1) {
-      return rewriter.notifyMatchFailure(
-          op->getLoc(), "subview with stride != 1 is not supported");
-    }
-
-    int64_t size = op.getStaticSize(0);
-    int64_t offset = op.getStaticOffset(0);
-    // Only support static sizes and offsets.
-    if (size == ShapedType::kDynamic || offset == ShapedType::kDynamic) {
-      return rewriter.notifyMatchFailure(
-          op->getLoc(), "subview with dynamic size or offset is not supported");
-    }
-
-    int elementsPerByte = dstBits / srcBits;
-    if (offset % elementsPerByte != 0) {
-      return rewriter.notifyMatchFailure(
-          op->getLoc(),
-          "subview with offset not multiple of elementsPerByte is not "
-          "supported");
-    }
-
-    size = ceilDiv(size, elementsPerByte);
-    offset = offset / elementsPerByte;
-
-    rewriter.replaceOpWithNewOp<memref::SubViewOp>(
-        op, newTy, *adaptor.getODSOperands(0).begin(), offset, size,
-        op.getStaticStrides());
-    return success();
+    return convertCastingOp(rewriter, adaptor, op, newTy);
   }
 };
 
@@ -291,9 +344,10 @@ void memref::populateMemRefNarrowTypeEmulationPatterns(
     RewritePatternSet &patterns) {
 
   // Populate `memref.*` conversion patterns.
-  patterns.add<ConvertMemRefAlloc, ConvertMemRefLoad,
-               ConvertMemRefAssumeAlignment, ConvertMemRefSubview>(
-      typeConverter, patterns.getContext());
+  patterns
+      .add<ConvertMemRefAlloc, ConvertMemRefLoad, ConvertMemRefAssumeAlignment,
+           ConvertMemRefSubview, ConvertMemRefReinterpretCast>(
+          typeConverter, patterns.getContext());
   memref::populateResolveExtractStridedMetadataPatterns(patterns);
 }
 

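For reference, the rescaling done by convertCastingOp above is:

    elementsPerByte = dstBits / srcBits
    newOffset = offset / elementsPerByte      (requires offset % elementsPerByte == 0)
    newSize   = ceilDiv(size, elementsPerByte)

Worked through with the i4 -> i32 configuration exercised by the CHECK32 lines
below: elementsPerByte = 32 / 4 = 8, so offset 8 becomes 1 and size 25 becomes
ceilDiv(25, 8) = 4, matching the memref<4xi32, strided<[1], offset: 1>> result
type in the test.
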
diff --git a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
index 6ed97f05aa7cff2..2c411defb47e3ba 100644
--- a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
+++ b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
@@ -174,3 +174,61 @@ func.func @memref_strided_i4(%idx : index) -> i4 {
 //       CHECK32:   %[[ALLOC:.+]] = memref.alloc() : memref<16xi32>
 //       CHECK32:   %[[SUBVIEW:.+]] = memref.subview %[[ALLOC]][4] [4] [1] : memref<16xi32> to memref<4xi32, strided<[1], offset: 4>>
 //       CHECK32:   %[[LOAD:.+]] = memref.load %[[SUBVIEW]]
+
+// -----
+
+func.func @reinterpret_cast_memref_load_0D() -> i4 {
+    %0 = memref.alloc() : memref<5xi4>
+    %reinterpret_cast_0 = memref.reinterpret_cast %0 to offset: [0], sizes: [], strides: [] : memref<5xi4> to memref<i4>
+    %1 = memref.load %reinterpret_cast_0[] : memref<i4>
+    return %1 : i4
+}
+// CHECK-LABEL: func @reinterpret_cast_memref_load_0D()
+//       CHECK:   %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
+//       CHECK:   %[[RE_CAST:.+]] = memref.reinterpret_cast %[[ALLOC]] to offset: [0], sizes: [], strides: [] : memref<3xi8> to memref<i8>
+//       CHECK:   %[[LOAD:.+]] = memref.load %[[RE_CAST]][] : memref<i8>
+//       CHECK:   %[[TRUNC:.+]] = arith.trunci %[[LOAD]] : i8 to i4
+//       CHECK:   return %[[TRUNC]]
+
+// CHECK32-LABEL: func @reinterpret_cast_memref_load_0D()
+//       CHECK32:   %[[ALLOC:.+]] = memref.alloc() : memref<1xi32>
+//       CHECK32:   %[[RE_CAST:.+]] = memref.reinterpret_cast %[[ALLOC]] to offset: [0], sizes: [], strides: [] : memref<1xi32> to memref<i32>
+//       CHECK32:   %[[LOAD:.+]] = memref.load %[[RE_CAST]][] : memref<i32>
+//       CHECK32:   %[[TRUNC:.+]] = arith.trunci %[[LOAD]] : i32 to i4
+//       CHECK32:   return %[[TRUNC]]
+
+// -----
+
+func.func @reinterpret_cast_memref_load_1D(%arg0: index) -> i4 {
+    %0 = memref.alloc() : memref<5x5xi4>
+    %reinterpret_cast_0 = memref.reinterpret_cast %0 to offset: [8], sizes: [25], strides: [1] : memref<5x5xi4> to memref<25xi4, strided<[1], offset:8>>
+    %1 = memref.load %reinterpret_cast_0[%arg0] : memref<25xi4, strided<[1], offset:8>>
+    return %1 : i4
+}
+//   CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0] -> (s0 floordiv 2)>
+//   CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 2) * 8)>
+//       CHECK: func @reinterpret_cast_memref_load_1D(
+//  CHECK-SAME: %[[ARG0:.+]]: index
+//       CHECK:   %[[ALLOC:.+]] = memref.alloc() : memref<13xi8>
+//       CHECK:   %[[RE_CAST:.+]] = memref.reinterpret_cast %[[ALLOC]] to offset: [4], sizes: [13], strides: [1] : memref<13xi8> to memref<13xi8, strided<[1], offset: 4>>
+//       CHECK:   %[[INDEX:.+]] = affine.apply #[[MAP]]()[%[[ARG0]]]
+//       CHECK:   %[[LOAD:.+]] = memref.load %[[RE_CAST]][%[[INDEX]]] : memref<13xi8, strided<[1], offset: 4>>
+//       CHECK:   %[[OFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]]]
+//       CHECK:   %[[CAST:.+]] = arith.index_cast %[[OFFSET]] : index to i8
+//       CHECK:   %[[SHR:.+]] = arith.shrsi %[[LOAD]], %[[CAST]] : i8
+//       CHECK:   %[[TRUNC:.+]] = arith.trunci %[[SHR]] : i8 to i4
+//       CHECK:   return %[[TRUNC]]
+
+//   CHECK32-DAG: #[[MAP:.+]] = affine_map<()[s0] -> (s0 floordiv 8)>
+//   CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 8) * 32)>
+//       CHECK32: func @reinterpret_cast_memref_load_1D(
+//  CHECK32-SAME: %[[ARG0:.+]]: index
+//       CHECK32:   %[[ALLOC:.+]] = memref.alloc() : memref<4xi32>
+//       CHECK32:   %[[RE_CAST:.+]] = memref.reinterpret_cast %[[ALLOC]] to offset: [1], sizes: [4], strides: [1] : memref<4xi32> to memref<4xi32, strided<[1], offset: 1>>
+//       CHECK32:   %[[INDEX:.+]] = affine.apply #[[MAP]]()[%[[ARG0]]]
+//       CHECK32:   %[[LOAD:.+]] = memref.load %[[RE_CAST]][%[[INDEX]]] : memref<4xi32, strided<[1], offset: 1>>
+//       CHECK32:   %[[OFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]]]
+//       CHECK32:   %[[CAST:.+]] = arith.index_cast %[[OFFSET]] : index to i32
+//       CHECK32:   %[[SHR:.+]] = arith.shrsi %[[LOAD]], %[[CAST]] : i32
+//       CHECK32:   %[[TRUNC:.+]] = arith.trunci %[[SHR]] : i32 to i4
+//       CHECK32:   return %[[TRUNC]]

