[llvm-branch-commits] [mlir] b823f84 - [mlir] Add support for `memref.alloca` sub-byte emulation (#73138)

via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Tue Nov 28 22:23:38 PST 2023


Author: Max191
Date: 2023-11-27T16:28:22-08:00
New Revision: b823f8469b5364411cde31a215c9bcbe0d3c08f7

URL: https://github.com/llvm/llvm-project/commit/b823f8469b5364411cde31a215c9bcbe0d3c08f7
DIFF: https://github.com/llvm/llvm-project/commit/b823f8469b5364411cde31a215c9bcbe0d3c08f7.diff

LOG: [mlir] Add support for `memref.alloca` sub-byte emulation (#73138)

Extends EmulateNarrowType.cpp to handle `memref.alloca` the same way as the
existing `memref.alloc` case, by templating the conversion pattern over the
op type (see the sketch below).

Fixes https://github.com/openxla/iree/issues/15515
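
For context, with an i8 load width the new pattern rewrites a sub-byte alloca
as follows. This is a condensed, hand-written sketch of the IR (SSA value
names are illustrative); the authoritative FileCheck expectations are in the
test diff below:

    %0 = memref.alloca() : memref<5xi4>
    %1 = memref.load %0[%arg0] : memref<5xi4>

becomes roughly:

    // 5 x i4 = 20 bits, packed into ceil(20/8) = 3 bytes.
    %alloca = memref.alloca() : memref<3xi8>
    // Byte index of element %arg0: %arg0 floordiv 2.
    %idx = affine.apply affine_map<()[s0] -> (s0 floordiv 2)>()[%arg0]
    %byte = memref.load %alloca[%idx] : memref<3xi8>
    // Bit offset within the byte: (%arg0 mod 2) * 4.
    %off = affine.apply affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 2) * 8)>()[%arg0]
    %shift = arith.index_cast %off : index to i8
    %shr = arith.shrsi %byte, %shift : i8
    %res = arith.trunci %shr : i8 to i4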

Added: 
    

Modified: 
    mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
    mlir/test/Dialect/MemRef/emulate-narrow-type.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
index dec5936fa7e83ce..e5801c3733ed5a8 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
@@ -112,18 +112,22 @@ static Value getOffsetForBitwidth(Location loc, OpFoldResult srcIdx,
 namespace {
 
 //===----------------------------------------------------------------------===//
-// ConvertMemRefAlloc
+// ConvertMemRefAllocation
 //===----------------------------------------------------------------------===//
 
-struct ConvertMemRefAlloc final : OpConversionPattern<memref::AllocOp> {
-  using OpConversionPattern::OpConversionPattern;
+template <typename OpTy>
+struct ConvertMemRefAllocation final : OpConversionPattern<OpTy> {
+  using OpConversionPattern<OpTy>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor,
+  matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    auto currentType = op.getMemref().getType().cast<MemRefType>();
-    auto newResultType =
-        getTypeConverter()->convertType(op.getType()).dyn_cast<MemRefType>();
+    static_assert(std::is_same<OpTy, memref::AllocOp>() ||
+                      std::is_same<OpTy, memref::AllocaOp>(),
+                  "expected only memref::AllocOp or memref::AllocaOp");
+    auto currentType = cast<MemRefType>(op.getMemref().getType());
+    auto newResultType = dyn_cast<MemRefType>(
+        this->getTypeConverter()->convertType(op.getType()));
     if (!newResultType) {
       return rewriter.notifyMatchFailure(
           op->getLoc(),
@@ -132,9 +136,9 @@ struct ConvertMemRefAlloc final : OpConversionPattern<memref::AllocOp> {
 
     // Special case zero-rank memrefs.
     if (currentType.getRank() == 0) {
-      rewriter.replaceOpWithNewOp<memref::AllocOp>(
-          op, newResultType, ValueRange{}, adaptor.getSymbolOperands(),
-          adaptor.getAlignmentAttr());
+      rewriter.replaceOpWithNewOp<OpTy>(op, newResultType, ValueRange{},
+                                        adaptor.getSymbolOperands(),
+                                        adaptor.getAlignmentAttr());
       return success();
     }
 
@@ -156,9 +160,9 @@ struct ConvertMemRefAlloc final : OpConversionPattern<memref::AllocOp> {
           rewriter, loc, linearizedMemRefInfo.linearizedSize));
     }
 
-    rewriter.replaceOpWithNewOp<memref::AllocOp>(
-        op, newResultType, dynamicLinearizedSize, adaptor.getSymbolOperands(),
-        adaptor.getAlignmentAttr());
+    rewriter.replaceOpWithNewOp<OpTy>(op, newResultType, dynamicLinearizedSize,
+                                      adaptor.getSymbolOperands(),
+                                      adaptor.getAlignmentAttr());
     return success();
   }
 };
@@ -344,10 +348,11 @@ void memref::populateMemRefNarrowTypeEmulationPatterns(
     RewritePatternSet &patterns) {
 
   // Populate `memref.*` conversion patterns.
-  patterns
-      .add<ConvertMemRefAlloc, ConvertMemRefLoad, ConvertMemRefAssumeAlignment,
-           ConvertMemRefSubview, ConvertMemRefReinterpretCast>(
-          typeConverter, patterns.getContext());
+  patterns.add<ConvertMemRefAllocation<memref::AllocOp>,
+               ConvertMemRefAllocation<memref::AllocaOp>, ConvertMemRefLoad,
+               ConvertMemRefAssumeAlignment, ConvertMemRefSubview,
+               ConvertMemRefReinterpretCast>(typeConverter,
+                                             patterns.getContext());
   memref::populateResolveExtractStridedMetadataPatterns(patterns);
 }
 

diff --git a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
index 2c411defb47e3ba..dc32a59a1a14931 100644
--- a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
+++ b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
@@ -232,3 +232,36 @@ func.func @reinterpret_cast_memref_load_1D(%arg0: index) -> i4 {
 //       CHECK32:   %[[SHR:.+]] = arith.shrsi %[[LOAD]], %[[CAST]] : i32
 //       CHECK32:   %[[TRUNC:.+]] = arith.trunci %[[SHR]] : i32 to i4
 //       CHECK32:   return %[[TRUNC]]
+
+// -----
+
+func.func @memref_alloca_load_i4(%arg0: index) -> i4 {
+    %0 = memref.alloca() : memref<5xi4>
+    %1 = memref.load %0[%arg0] : memref<5xi4>
+    return %1 : i4
+}
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 floordiv 2)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 2) * 8)
+//      CHECK: func @memref_alloca_load_i4(
+// CHECK-SAME:     %[[ARG0:.+]]: index
+//      CHECK:   %[[ALLOCA:.+]] = memref.alloca() : memref<3xi8>
+//      CHECK:   %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
+//      CHECK:   %[[LOADVAL:.+]] = memref.load %[[ALLOCA]][%[[INDEX]]]
+//      CHECK:   %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]]]
+//      CHECK:   %[[CAST:.+]] = arith.index_cast %[[BITOFFSET]] : index to i8
+//      CHECK:   %[[SHIFTRT:.+]] = arith.shrsi %[[LOADVAL]], %[[CAST]]
+//      CHECK:   %[[TRUNC:.+]] = arith.trunci %[[SHIFTRT]] : i8 to i4
+//      CHECK:   return %[[TRUNC]]
+
+//  CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 floordiv 8)>
+//  CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 8) * 32)
+//      CHECK32: func @memref_alloca_load_i4(
+// CHECK32-SAME:     %[[ARG0:.+]]: index
+//      CHECK32:   %[[ALLOCA:.+]] = memref.alloca() : memref<1xi32>
+//      CHECK32:   %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
+//      CHECK32:   %[[LOADVAL:.+]] = memref.load %[[ALLOCA]][%[[INDEX]]]
+//      CHECK32:   %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]]]
+//      CHECK32:   %[[CAST:.+]] = arith.index_cast %[[BITOFFSET]] : index to i32
+//      CHECK32:   %[[SHIFTRT:.+]] = arith.shrsi %[[LOADVAL]], %[[CAST]]
+//      CHECK32:   %[[TRUNC:.+]] = arith.trunci %[[SHIFTRT]] : i32 to i4
+//      CHECK32:   return %[[TRUNC]]
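
The templated pattern also carries the existing zero-rank special case over to
`memref.alloca` (the early-return branch in matchAndRewrite above). A minimal
sketch of that path, assuming the same i8 load width; this case is not covered
by the tests in this diff, and the converted type comes from the type
converter:

    // Zero-rank input:
    %0 = memref.alloca() : memref<i4>
    // is rewritten directly to the converted container type:
    %0 = memref.alloca() : memref<i8>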