[Mlir-commits] [mlir] [mlir] Add support for `memref.alloca` sub-byte emulation (PR #73138)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Nov 22 07:25:56 PST 2023
https://github.com/Max191 created https://github.com/llvm/llvm-project/pull/73138
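Adds sub-byte (narrow-type) emulation support for `memref.alloca` in EmulateNarrowType.cpp, mirroring the existing handling of `memref.alloc`: the alloc conversion pattern is templated over the op type and registered for both ops.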
>From 94bace33ee440eb32e372e02b681285b3757b081 Mon Sep 17 00:00:00 2001
From: Max Dawkins <max.dawkins at gmail.com>
Date: Fri, 10 Nov 2023 11:27:13 -0500
Subject: [PATCH] [mlir] Add support for `memref.alloca` sub-byte emulation
---
.../MemRef/Transforms/EmulateNarrowType.cpp | 31 +++++++++--------
.../Dialect/MemRef/emulate-narrow-type.mlir | 33 +++++++++++++++++++
2 files changed, 51 insertions(+), 13 deletions(-)
diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
index 9f58e9055acadbb..9b197dc51265d92 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
@@ -56,15 +56,19 @@ namespace {
// ConvertMemRefAlloc
//===----------------------------------------------------------------------===//
-struct ConvertMemRefAlloc final : OpConversionPattern<memref::AllocOp> {
- using OpConversionPattern::OpConversionPattern;
+template <typename OpTy>
+struct ConvertMemRefAlloc final : OpConversionPattern<OpTy> {
+ using OpConversionPattern<OpTy>::OpConversionPattern;
LogicalResult
- matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor,
+ matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto currentType = op.getMemref().getType().cast<MemRefType>();
- auto newResultType =
- getTypeConverter()->convertType(op.getType()).dyn_cast<MemRefType>();
+ static_assert(std::is_same<OpTy, memref::AllocOp>() ||
+ std::is_same<OpTy, memref::AllocaOp>(),
+ "expected only memref::AllocOp or memref::AllocaOp");
+ auto currentType = cast<MemRefType>(op.getMemref().getType());
+ auto newResultType = dyn_cast<MemRefType>(
+ this->getTypeConverter()->convertType(op.getType()));
if (!newResultType) {
return rewriter.notifyMatchFailure(
op->getLoc(),
@@ -73,9 +77,9 @@ struct ConvertMemRefAlloc final : OpConversionPattern<memref::AllocOp> {
// Special case zero-rank memrefs.
if (currentType.getRank() == 0) {
- rewriter.replaceOpWithNewOp<memref::AllocOp>(
- op, newResultType, ValueRange{}, adaptor.getSymbolOperands(),
- adaptor.getAlignmentAttr());
+ rewriter.replaceOpWithNewOp<OpTy>(op, newResultType, ValueRange{},
+ adaptor.getSymbolOperands(),
+ adaptor.getAlignmentAttr());
return success();
}
@@ -97,9 +101,9 @@ struct ConvertMemRefAlloc final : OpConversionPattern<memref::AllocOp> {
rewriter, loc, linearizedMemRefInfo.linearizedSize));
}
- rewriter.replaceOpWithNewOp<memref::AllocOp>(
- op, newResultType, dynamicLinearizedSize, adaptor.getSymbolOperands(),
- adaptor.getAlignmentAttr());
+ rewriter.replaceOpWithNewOp<OpTy>(op, newResultType, dynamicLinearizedSize,
+ adaptor.getSymbolOperands(),
+ adaptor.getAlignmentAttr());
return success();
}
};
@@ -291,7 +295,8 @@ void memref::populateMemRefNarrowTypeEmulationPatterns(
RewritePatternSet &patterns) {
// Populate `memref.*` conversion patterns.
- patterns.add<ConvertMemRefAlloc, ConvertMemRefLoad,
+ patterns.add<ConvertMemRefAlloc<memref::AllocOp>,
+ ConvertMemRefAlloc<memref::AllocaOp>, ConvertMemRefLoad,
ConvertMemRefAssumeAlignment, ConvertMemRefSubview>(
typeConverter, patterns.getContext());
memref::populateResolveExtractStridedMetadataPatterns(patterns);
diff --git a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
index 6ed97f05aa7cff2..a25b0a668499a23 100644
--- a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
+++ b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
@@ -174,3 +174,36 @@ func.func @memref_strided_i4(%idx : index) -> i4 {
// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<16xi32>
// CHECK32: %[[SUBVIEW:.+]] = memref.subview %[[ALLOC]][4] [4] [1] : memref<16xi32> to memref<4xi32, strided<[1], offset: 4>>
// CHECK32: %[[LOAD:.+]] = memref.load %[[SUBVIEW]]
+
+// -----
+
+func.func @memref_alloca_load_i4(%arg0: index) -> i4 {
+ %0 = memref.alloca() : memref<5xi4>
+ %1 = memref.load %0[%arg0] : memref<5xi4>
+ return %1 : i4
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 floordiv 2)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 2) * 8)
+// CHECK: func @memref_alloca_load_i4(
+// CHECK-SAME: %[[ARG0:.+]]: index
+// CHECK: %[[ALLOCA:.+]] = memref.alloca() : memref<3xi8>
+// CHECK: %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
+// CHECK: %[[LOADVAL:.+]] = memref.load %[[ALLOCA]][%[[INDEX]]]
+// CHECK: %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]]]
+// CHECK: %[[CAST:.+]] = arith.index_cast %[[BITOFFSET]] : index to i8
+// CHECK: %[[SHIFTRT:.+]] = arith.shrsi %[[LOADVAL]], %[[CAST]]
+// CHECK: %[[TRUNC:.+]] = arith.trunci %[[SHIFTRT]] : i8 to i4
+// CHECK: return %[[TRUNC]]
+
+// CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 floordiv 8)>
+// CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 8) * 32)
+// CHECK32: func @memref_alloca_load_i4(
+// CHECK32-SAME: %[[ARG0:.+]]: index
+// CHECK32: %[[ALLOCA:.+]] = memref.alloca() : memref<1xi32>
+// CHECK32: %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
+// CHECK32: %[[LOADVAL:.+]] = memref.load %[[ALLOCA]][%[[INDEX]]]
+// CHECK32: %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]]]
+// CHECK32: %[[CAST:.+]] = arith.index_cast %[[BITOFFSET]] : index to i32
+// CHECK32: %[[SHIFTRT:.+]] = arith.shrsi %[[LOADVAL]], %[[CAST]]
+// CHECK32: %[[TRUNC:.+]] = arith.trunci %[[SHIFTRT]] : i32 to i4
+// CHECK32: return %[[TRUNC]]
More information about the Mlir-commits
mailing list