[llvm-branch-commits] [mlir] b823f84 - [mlir] Add support for `memref.alloca` sub-byte emulation (#73138)
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Tue Nov 28 22:23:38 PST 2023
Author: Max191
Date: 2023-11-27T16:28:22-08:00
New Revision: b823f8469b5364411cde31a215c9bcbe0d3c08f7
URL: https://github.com/llvm/llvm-project/commit/b823f8469b5364411cde31a215c9bcbe0d3c08f7
DIFF: https://github.com/llvm/llvm-project/commit/b823f8469b5364411cde31a215c9bcbe0d3c08f7.diff
LOG: [mlir] Add support for `memref.alloca` sub-byte emulation (#73138)
Adds a case for `memref.alloca` to EmulateNarrowType, mirroring the existing
handling of `memref.alloc`.
Fixes https://github.com/openxla/iree/issues/15515
Added:
Modified:
mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
index dec5936fa7e83ce..e5801c3733ed5a8 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
@@ -112,18 +112,22 @@ static Value getOffsetForBitwidth(Location loc, OpFoldResult srcIdx,
namespace {
//===----------------------------------------------------------------------===//
-// ConvertMemRefAlloc
+// ConvertMemRefAllocation
//===----------------------------------------------------------------------===//
-struct ConvertMemRefAlloc final : OpConversionPattern<memref::AllocOp> {
- using OpConversionPattern::OpConversionPattern;
+template <typename OpTy>
+struct ConvertMemRefAllocation final : OpConversionPattern<OpTy> {
+ using OpConversionPattern<OpTy>::OpConversionPattern;
LogicalResult
- matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor,
+ matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto currentType = op.getMemref().getType().cast<MemRefType>();
- auto newResultType =
- getTypeConverter()->convertType(op.getType()).dyn_cast<MemRefType>();
+ static_assert(std::is_same<OpTy, memref::AllocOp>() ||
+ std::is_same<OpTy, memref::AllocaOp>(),
+ "expected only memref::AllocOp or memref::AllocaOp");
+ auto currentType = cast<MemRefType>(op.getMemref().getType());
+ auto newResultType = dyn_cast<MemRefType>(
+ this->getTypeConverter()->convertType(op.getType()));
if (!newResultType) {
return rewriter.notifyMatchFailure(
op->getLoc(),
@@ -132,9 +136,9 @@ struct ConvertMemRefAlloc final : OpConversionPattern<memref::AllocOp> {
// Special case zero-rank memrefs.
if (currentType.getRank() == 0) {
- rewriter.replaceOpWithNewOp<memref::AllocOp>(
- op, newResultType, ValueRange{}, adaptor.getSymbolOperands(),
- adaptor.getAlignmentAttr());
+ rewriter.replaceOpWithNewOp<OpTy>(op, newResultType, ValueRange{},
+ adaptor.getSymbolOperands(),
+ adaptor.getAlignmentAttr());
return success();
}
@@ -156,9 +160,9 @@ struct ConvertMemRefAlloc final : OpConversionPattern<memref::AllocOp> {
rewriter, loc, linearizedMemRefInfo.linearizedSize));
}
- rewriter.replaceOpWithNewOp<memref::AllocOp>(
- op, newResultType, dynamicLinearizedSize, adaptor.getSymbolOperands(),
- adaptor.getAlignmentAttr());
+ rewriter.replaceOpWithNewOp<OpTy>(op, newResultType, dynamicLinearizedSize,
+ adaptor.getSymbolOperands(),
+ adaptor.getAlignmentAttr());
return success();
}
};
@@ -344,10 +348,11 @@ void memref::populateMemRefNarrowTypeEmulationPatterns(
RewritePatternSet &patterns) {
// Populate `memref.*` conversion patterns.
- patterns
- .add<ConvertMemRefAlloc, ConvertMemRefLoad, ConvertMemRefAssumeAlignment,
- ConvertMemRefSubview, ConvertMemRefReinterpretCast>(
- typeConverter, patterns.getContext());
+ patterns.add<ConvertMemRefAllocation<memref::AllocOp>,
+ ConvertMemRefAllocation<memref::AllocaOp>, ConvertMemRefLoad,
+ ConvertMemRefAssumeAlignment, ConvertMemRefSubview,
+ ConvertMemRefReinterpretCast>(typeConverter,
+ patterns.getContext());
memref::populateResolveExtractStridedMetadataPatterns(patterns);
}
diff --git a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
index 2c411defb47e3ba..dc32a59a1a14931 100644
--- a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
+++ b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
@@ -232,3 +232,36 @@ func.func @reinterpret_cast_memref_load_1D(%arg0: index) -> i4 {
// CHECK32: %[[SHR:.+]] = arith.shrsi %[[LOAD]], %[[CAST]] : i32
// CHECK32: %[[TRUNC:.+]] = arith.trunci %[[SHR]] : i32 to i4
// CHECK32: return %[[TRUNC]]
+
+// -----
+
+func.func @memref_alloca_load_i4(%arg0: index) -> i4 {
+ %0 = memref.alloca() : memref<5xi4>
+ %1 = memref.load %0[%arg0] : memref<5xi4>
+ return %1 : i4
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 floordiv 2)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 2) * 8)
+// CHECK: func @memref_alloca_load_i4(
+// CHECK-SAME: %[[ARG0:.+]]: index
+// CHECK: %[[ALLOCA:.+]] = memref.alloca() : memref<3xi8>
+// CHECK: %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
+// CHECK: %[[LOADVAL:.+]] = memref.load %[[ALLOCA]][%[[INDEX]]]
+// CHECK: %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]]]
+// CHECK: %[[CAST:.+]] = arith.index_cast %[[BITOFFSET]] : index to i8
+// CHECK: %[[SHIFTRT:.+]] = arith.shrsi %[[LOADVAL]], %[[CAST]]
+// CHECK: %[[TRUNC:.+]] = arith.trunci %[[SHIFTRT]] : i8 to i4
+// CHECK: return %[[TRUNC]]
+
+// CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 floordiv 8)>
+// CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 8) * 32)
+// CHECK32: func @memref_alloca_load_i4(
+// CHECK32-SAME: %[[ARG0:.+]]: index
+// CHECK32: %[[ALLOCA:.+]] = memref.alloca() : memref<1xi32>
+// CHECK32: %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
+// CHECK32: %[[LOADVAL:.+]] = memref.load %[[ALLOCA]][%[[INDEX]]]
+// CHECK32: %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]]]
+// CHECK32: %[[CAST:.+]] = arith.index_cast %[[BITOFFSET]] : index to i32
+// CHECK32: %[[SHIFTRT:.+]] = arith.shrsi %[[LOADVAL]], %[[CAST]]
+// CHECK32: %[[TRUNC:.+]] = arith.trunci %[[SHIFTRT]] : i32 to i4
+// CHECK32: return %[[TRUNC]]
More information about the llvm-branch-commits
mailing list