[Mlir-commits] [mlir] [mlir] Add support for `memref.alloca` sub-byte emulation (PR #73138)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Nov 22 07:25:56 PST 2023


https://github.com/Max191 created https://github.com/llvm/llvm-project/pull/73138

Adds a conversion case for `memref.alloca` to EmulateNarrowType, mirroring the existing `memref.alloc` case.

From 94bace33ee440eb32e372e02b681285b3757b081 Mon Sep 17 00:00:00 2001
From: Max Dawkins <max.dawkins at gmail.com>
Date: Fri, 10 Nov 2023 11:27:13 -0500
Subject: [PATCH] [mlir] Add support for `memref.alloca` sub-byte emulation

---
 .../MemRef/Transforms/EmulateNarrowType.cpp   | 31 +++++++++--------
 .../Dialect/MemRef/emulate-narrow-type.mlir   | 33 +++++++++++++++++++
 2 files changed, 51 insertions(+), 13 deletions(-)

diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
index 9f58e9055acadbb..9b197dc51265d92 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
@@ -56,15 +56,19 @@ namespace {
 // ConvertMemRefAlloc
 //===----------------------------------------------------------------------===//
 
-struct ConvertMemRefAlloc final : OpConversionPattern<memref::AllocOp> {
-  using OpConversionPattern::OpConversionPattern;
+template <typename OpTy>
+struct ConvertMemRefAlloc final : OpConversionPattern<OpTy> {
+  using OpConversionPattern<OpTy>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor,
+  matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    auto currentType = op.getMemref().getType().cast<MemRefType>();
-    auto newResultType =
-        getTypeConverter()->convertType(op.getType()).dyn_cast<MemRefType>();
+    static_assert(std::is_same<OpTy, memref::AllocOp>() ||
+                      std::is_same<OpTy, memref::AllocaOp>(),
+                  "expected only memref::AllocOp or memref::AllocaOp");
+    auto currentType = cast<MemRefType>(op.getMemref().getType());
+    auto newResultType = dyn_cast<MemRefType>(
+        this->getTypeConverter()->convertType(op.getType()));
     if (!newResultType) {
       return rewriter.notifyMatchFailure(
           op->getLoc(),
@@ -73,9 +77,9 @@ struct ConvertMemRefAlloc final : OpConversionPattern<memref::AllocOp> {
 
     // Special case zero-rank memrefs.
     if (currentType.getRank() == 0) {
-      rewriter.replaceOpWithNewOp<memref::AllocOp>(
-          op, newResultType, ValueRange{}, adaptor.getSymbolOperands(),
-          adaptor.getAlignmentAttr());
+      rewriter.replaceOpWithNewOp<OpTy>(op, newResultType, ValueRange{},
+                                        adaptor.getSymbolOperands(),
+                                        adaptor.getAlignmentAttr());
       return success();
     }
 
@@ -97,9 +101,9 @@ struct ConvertMemRefAlloc final : OpConversionPattern<memref::AllocOp> {
           rewriter, loc, linearizedMemRefInfo.linearizedSize));
     }
 
-    rewriter.replaceOpWithNewOp<memref::AllocOp>(
-        op, newResultType, dynamicLinearizedSize, adaptor.getSymbolOperands(),
-        adaptor.getAlignmentAttr());
+    rewriter.replaceOpWithNewOp<OpTy>(op, newResultType, dynamicLinearizedSize,
+                                      adaptor.getSymbolOperands(),
+                                      adaptor.getAlignmentAttr());
     return success();
   }
 };
@@ -291,7 +295,8 @@ void memref::populateMemRefNarrowTypeEmulationPatterns(
     RewritePatternSet &patterns) {
 
   // Populate `memref.*` conversion patterns.
-  patterns.add<ConvertMemRefAlloc, ConvertMemRefLoad,
+  patterns.add<ConvertMemRefAlloc<memref::AllocOp>,
+               ConvertMemRefAlloc<memref::AllocaOp>, ConvertMemRefLoad,
                ConvertMemRefAssumeAlignment, ConvertMemRefSubview>(
       typeConverter, patterns.getContext());
   memref::populateResolveExtractStridedMetadataPatterns(patterns);
diff --git a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
index 6ed97f05aa7cff2..a25b0a668499a23 100644
--- a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
+++ b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
@@ -174,3 +174,36 @@ func.func @memref_strided_i4(%idx : index) -> i4 {
 //       CHECK32:   %[[ALLOC:.+]] = memref.alloc() : memref<16xi32>
 //       CHECK32:   %[[SUBVIEW:.+]] = memref.subview %[[ALLOC]][4] [4] [1] : memref<16xi32> to memref<4xi32, strided<[1], offset: 4>>
 //       CHECK32:   %[[LOAD:.+]] = memref.load %[[SUBVIEW]]
+
+// -----
+
+func.func @memref_alloca_load_i4(%arg0: index) -> i4 {
+    %0 = memref.alloca() : memref<5xi4>
+    %1 = memref.load %0[%arg0] : memref<5xi4>
+    return %1 : i4
+}
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 floordiv 2)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 2) * 8)
+//      CHECK: func @memref_alloca_load_i4(
+// CHECK-SAME:     %[[ARG0:.+]]: index
+//      CHECK:   %[[ALLOCA:.+]] = memref.alloca() : memref<3xi8>
+//      CHECK:   %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
+//      CHECK:   %[[LOADVAL:.+]] = memref.load %[[ALLOCA]][%[[INDEX]]]
+//      CHECK:   %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]]]
+//      CHECK:   %[[CAST:.+]] = arith.index_cast %[[BITOFFSET]] : index to i8
+//      CHECK:   %[[SHIFTRT:.+]] = arith.shrsi %[[LOADVAL]], %[[CAST]]
+//      CHECK:   %[[TRUNC:.+]] = arith.trunci %[[SHIFTRT]] : i8 to i4
+//      CHECK:   return %[[TRUNC]]
+
+//  CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 floordiv 8)>
+//  CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 8) * 32)
+//      CHECK32: func @memref_alloca_load_i4(
+// CHECK32-SAME:     %[[ARG0:.+]]: index
+//      CHECK32:   %[[ALLOCA:.+]] = memref.alloca() : memref<1xi32>
+//      CHECK32:   %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
+//      CHECK32:   %[[LOADVAL:.+]] = memref.load %[[ALLOCA]][%[[INDEX]]]
+//      CHECK32:   %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]]]
+//      CHECK32:   %[[CAST:.+]] = arith.index_cast %[[BITOFFSET]] : index to i32
+//      CHECK32:   %[[SHIFTRT:.+]] = arith.shrsi %[[LOADVAL]], %[[CAST]]
+//      CHECK32:   %[[TRUNC:.+]] = arith.trunci %[[SHIFTRT]] : i32 to i4
+//      CHECK32:   return %[[TRUNC]]



More information about the Mlir-commits mailing list