[Mlir-commits] [mlir] [mlir][MemRef] Add more ops to narrow type support, strided metadata expansion (PR #102228)
Krzysztof Drewniak
llvmlistbot at llvm.org
Wed Aug 14 11:46:16 PDT 2024
https://github.com/krzysz00 updated https://github.com/llvm/llvm-project/pull/102228
From 54a8fb251a45d63aa350c4cdad6c35a832cf8627 Mon Sep 17 00:00:00 2001
From: Krzysztof Drewniak <Krzysztof.Drewniak at amd.com>
Date: Tue, 16 Jul 2024 22:47:28 +0000
Subject: [PATCH] [mlir][MemRef] Add more ops to narrow type support, strided
metadata
- Add support for memory_space_cast to strided metadata expansion and
  narrow type emulation (see the example below)
- Add support for expand_shape to narrow type emulation (like
  collapse_shape, it is a no-op after linearization) and to
  expand-strided-metadata (mirroring the collapse_shape pattern)
- Add support for memref.dealloc to narrow type emulation (it is a
  trivial rewrite) and for memref.copy (which is unsupported when used
  for a layout change but is a trivial rewrite otherwise)
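
As an illustrative sketch (the emulated sizes assume i4 values packed
into an i8 container, as in the tests below), the memory_space_cast
support rewrites

    %cast = memref.memory_space_cast %src
        : memref<32x128xi4, 1> to memref<32x128xi4>

into a cast between the linearized backing buffers:

    %cast = memref.memory_space_cast %src
        : memref<2048xi8, 1> to memref<2048xi8>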
---
.../MemRef/Transforms/EmulateNarrowType.cpp | 93 ++++++++++++++++++-
.../Transforms/ExpandStridedMetadata.cpp | 87 +++++++++++++++++
.../Dialect/MemRef/emulate-narrow-type.mlir | 68 ++++++++++++++
.../MemRef/expand-strided-metadata.mlir | 38 ++++++++
4 files changed, 283 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
index 88d56a8fbec749..a45b79194a7580 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
@@ -234,6 +234,46 @@ struct ConvertMemRefAssumeAlignment final
}
};
+//===----------------------------------------------------------------------===//
+// ConvertMemRefCopy
+//===----------------------------------------------------------------------===//
+
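+/// Rewrite `memref.copy` to copy between the converted (linearized)
+/// buffers. Copies that also change layout are rejected as currently
+/// unimplemented, since under emulation they would require reshuffling
+/// data rather than a direct copy.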
+struct ConvertMemRefCopy final : OpConversionPattern<memref::CopyOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto maybeRankedSource = dyn_cast<MemRefType>(op.getSource().getType());
+ auto maybeRankedDest = dyn_cast<MemRefType>(op.getTarget().getType());
+ if (maybeRankedSource && maybeRankedDest &&
+ maybeRankedSource.getLayout() != maybeRankedDest.getLayout())
+ return rewriter.notifyMatchFailure(
+ op, llvm::formatv("memref.copy emulation with distinct layouts ({0} "
+ "and {1}) is currently unimplemented",
+ maybeRankedSource.getLayout(),
+ maybeRankedDest.getLayout()));
+ rewriter.replaceOpWithNewOp<memref::CopyOp>(op, adaptor.getSource(),
+ adaptor.getTarget());
+ return success();
+ }
+};
+
+//===----------------------------------------------------------------------===//
+// ConvertMemRefDealloc
+//===----------------------------------------------------------------------===//
+
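+/// Rewrite `memref.dealloc` to free the converted (linearized) buffer
+/// instead of the original narrow-type memref.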
+struct ConvertMemRefDealloc final : OpConversionPattern<memref::DeallocOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ rewriter.replaceOpWithNewOp<memref::DeallocOp>(op, adaptor.getMemref());
+ return success();
+ }
+};
+
//===----------------------------------------------------------------------===//
// ConvertMemRefLoad
//===----------------------------------------------------------------------===//
@@ -300,6 +340,30 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
}
};
+//===----------------------------------------------------------------------===//
+// ConvertMemRefMemorySpaceCast
+//===----------------------------------------------------------------------===//
+
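+/// Rewrite `memref.memory_space_cast` to cast between the converted
+/// (linearized) buffer types, preserving the source and destination
+/// memory spaces.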
+struct ConvertMemRefMemorySpaceCast final
+ : OpConversionPattern<memref::MemorySpaceCastOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(memref::MemorySpaceCastOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Type newTy = getTypeConverter()->convertType(op.getDest().getType());
+ if (!newTy) {
+ return rewriter.notifyMatchFailure(
+ op->getLoc(), llvm::formatv("failed to convert memref type: {0}",
+ op.getDest().getType()));
+ }
+
+ rewriter.replaceOpWithNewOp<memref::MemorySpaceCastOp>(op, newTy,
+ adaptor.getSource());
+ return success();
+ }
+};
+
//===----------------------------------------------------------------------===//
// ConvertMemRefReinterpretCast
//===----------------------------------------------------------------------===//
@@ -490,6 +554,28 @@ struct ConvertMemRefCollapseShape final
}
};
+/// Emulating `memref.expand_shape` is a no-op: memrefs are flattened to a
+/// single dimension as part of the emulation, so the expansion would
+/// immediately be undone.
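+///
+/// For example, expanding `memref<256x128xi4>` (emulated as a flat
+/// `memref<16384xi8>`) into `memref<32x8x128xi4>` simply forwards the
+/// linearized buffer.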
+struct ConvertMemRefExpandShape final
+ : OpConversionPattern<memref::ExpandShapeOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(memref::ExpandShapeOp expandShapeOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Value srcVal = adaptor.getSrc();
+ auto newTy = dyn_cast<MemRefType>(srcVal.getType());
+ if (!newTy)
+ return failure();
+
+ if (newTy.getRank() != 1)
+ return failure();
+
+ rewriter.replaceOp(expandShapeOp, srcVal);
+ return success();
+ }
+};
} // end anonymous namespace
//===----------------------------------------------------------------------===//
@@ -502,9 +588,10 @@ void memref::populateMemRefNarrowTypeEmulationPatterns(
// Populate `memref.*` conversion patterns.
patterns.add<ConvertMemRefAllocation<memref::AllocOp>,
- ConvertMemRefAllocation<memref::AllocaOp>,
- ConvertMemRefCollapseShape, ConvertMemRefLoad,
- ConvertMemrefStore, ConvertMemRefAssumeAlignment,
+ ConvertMemRefAllocation<memref::AllocaOp>, ConvertMemRefCopy,
+ ConvertMemRefDealloc, ConvertMemRefCollapseShape,
+ ConvertMemRefExpandShape, ConvertMemRefLoad, ConvertMemrefStore,
+ ConvertMemRefAssumeAlignment, ConvertMemRefMemorySpaceCast,
ConvertMemRefSubview, ConvertMemRefReinterpretCast>(
typeConverter, patterns.getContext());
memref::populateResolveExtractStridedMetadataPatterns(patterns);
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
index 585c5b73814219..a2049ba4a4924d 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
@@ -726,6 +726,41 @@ struct ExtractStridedMetadataOpCollapseShapeFolder
}
};
+/// Pattern to replace `extract_strided_metadata(expand_shape)`
+/// with the results of computing the sizes and strides on the expanded shape
+/// and dividing up dimensions into static and dynamic parts as needed.
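+///
+/// For example, the metadata for an expansion of a contiguous
+/// `memref<12xf32>` into `memref<3x4xf32>` keeps the base and offset and
+/// produces sizes [3, 4] and strides [4, 1], all computed from the source
+/// metadata.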
+struct ExtractStridedMetadataOpExpandShapeFolder
+ : OpRewritePattern<memref::ExtractStridedMetadataOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
+ PatternRewriter &rewriter) const override {
+ auto expandShapeOp = op.getSource().getDefiningOp<memref::ExpandShapeOp>();
+ if (!expandShapeOp)
+ return failure();
+
+ FailureOr<StridedMetadata> stridedMetadata =
+ resolveReshapeStridedMetadata<memref::ExpandShapeOp>(
+ rewriter, expandShapeOp, getExpandedSizes, getExpandedStrides);
+ if (failed(stridedMetadata)) {
+ return rewriter.notifyMatchFailure(
+ op, "failed to resolve metadata in terms of source expand_shape op");
+ }
+
+ Location loc = expandShapeOp.getLoc();
+ SmallVector<Value> results;
+ results.push_back(stridedMetadata->basePtr);
+ results.push_back(getValueOrCreateConstantIndexOp(rewriter, loc,
+ stridedMetadata->offset));
+ results.append(
+ getValueOrCreateConstantIndexOp(rewriter, loc, stridedMetadata->sizes));
+ results.append(getValueOrCreateConstantIndexOp(rewriter, loc,
+ stridedMetadata->strides));
+ rewriter.replaceOp(op, results);
+ return success();
+ }
+};
+
/// Replace `base, offset, sizes, strides =
/// extract_strided_metadata(allocLikeOp)`
///
@@ -1060,6 +1095,54 @@ class ExtractStridedMetadataOpCastFolder
}
};
+/// Replace `base, offset, sizes, strides = extract_strided_metadata(
+/// memory_space_cast(src) to dstTy)`
+/// with
+/// ```
+/// oldBase, offset, sizes, strides = extract_strided_metadata(src)
+/// destBaseTy = type(oldBase) with memory space from destTy
+/// base = memory_space_cast(oldBase) to destBaseTy
+/// ```
+///
+/// In other words, propagate metadata extraction across memory space casts.
+class ExtractStridedMetadataOpMemorySpaceCastFolder
+ : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult
+ matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
+ PatternRewriter &rewriter) const override {
+ Location loc = extractStridedMetadataOp.getLoc();
+ Value source = extractStridedMetadataOp.getSource();
+ auto memSpaceCastOp = source.getDefiningOp<memref::MemorySpaceCastOp>();
+ if (!memSpaceCastOp)
+ return failure();
+ auto newExtractStridedMetadata =
+ rewriter.create<memref::ExtractStridedMetadataOp>(
+ loc, memSpaceCastOp.getSource());
+ SmallVector<Value> results(newExtractStridedMetadata.getResults());
+    // As with most other strided metadata rewrite patterns, don't introduce
+    // a use of the base pointer where none existed. This needs to happen
+    // here, as opposed to in later dead-code elimination, because these
+    // patterns are sometimes used during dialect conversion (see
+    // EmulateNarrowType, for example), so adding spurious usages would cause
+    // a pre-legalization value to be live that would have been dead had this
+    // pattern not run.
+ if (!extractStridedMetadataOp.getBaseBuffer().use_empty()) {
+ auto baseBuffer = results[0];
+ auto baseBufferType = cast<MemRefType>(baseBuffer.getType());
+ MemRefType::Builder newTypeBuilder(baseBufferType);
+ newTypeBuilder.setMemorySpace(
+ memSpaceCastOp.getResult().getType().getMemorySpace());
+ results[0] = rewriter.create<memref::MemorySpaceCastOp>(
+ loc, Type{newTypeBuilder}, baseBuffer);
+ } else {
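+      // The base buffer has no uses, so no replacement value is needed;
+      // leave a null placeholder rather than creating an unused cast.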
+ results[0] = nullptr;
+ }
+ rewriter.replaceOp(extractStridedMetadataOp, results);
+ return success();
+ }
+};
+
/// Replace `base, offset =
/// extract_strided_metadata(extract_strided_metadata(src)#0)`
/// With
@@ -1099,11 +1182,13 @@ void memref::populateExpandStridedMetadataPatterns(
ExtractStridedMetadataOpAllocFolder<memref::AllocOp>,
ExtractStridedMetadataOpAllocFolder<memref::AllocaOp>,
ExtractStridedMetadataOpCollapseShapeFolder,
+ ExtractStridedMetadataOpExpandShapeFolder,
ExtractStridedMetadataOpGetGlobalFolder,
RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
ExtractStridedMetadataOpReinterpretCastFolder,
ExtractStridedMetadataOpSubviewFolder,
ExtractStridedMetadataOpCastFolder,
+ ExtractStridedMetadataOpMemorySpaceCastFolder,
ExtractStridedMetadataOpExtractStridedMetadataFolder>(
patterns.getContext());
}
@@ -1113,11 +1198,13 @@ void memref::populateResolveExtractStridedMetadataPatterns(
patterns.add<ExtractStridedMetadataOpAllocFolder<memref::AllocOp>,
ExtractStridedMetadataOpAllocFolder<memref::AllocaOp>,
ExtractStridedMetadataOpCollapseShapeFolder,
+ ExtractStridedMetadataOpExpandShapeFolder,
ExtractStridedMetadataOpGetGlobalFolder,
ExtractStridedMetadataOpSubviewFolder,
RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
ExtractStridedMetadataOpReinterpretCastFolder,
ExtractStridedMetadataOpCastFolder,
+ ExtractStridedMetadataOpMemorySpaceCastFolder,
ExtractStridedMetadataOpExtractStridedMetadataFolder>(
patterns.getContext());
}
diff --git a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
index a67237b5e4dd19..540da239fced08 100644
--- a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
+++ b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
@@ -6,11 +6,13 @@ func.func @memref_i8() -> i8 {
%c3 = arith.constant 3 : index
%m = memref.alloc() : memref<4xi8, 1>
%v = memref.load %m[%c3] : memref<4xi8, 1>
+ memref.dealloc %m : memref<4xi8, 1>
return %v : i8
}
// CHECK-LABEL: func @memref_i8()
// CHECK: %[[M:.+]] = memref.alloc() : memref<4xi8, 1>
// CHECK-NEXT: %[[V:.+]] = memref.load %[[M]][%{{.+}}] : memref<4xi8, 1>
+// CHECK-NEXT: memref.dealloc %[[M]]
// CHECK-NEXT: return %[[V]]
// CHECK32-LABEL: func @memref_i8()
@@ -21,6 +23,7 @@ func.func @memref_i8() -> i8 {
// CHECK32: %[[CAST:.+]] = arith.index_cast %[[C24]] : index to i32
// CHECK32: %[[SHIFTRT:.+]] = arith.shrsi %[[V]], %[[CAST]]
// CHECK32: %[[TRUNC:.+]] = arith.trunci %[[SHIFTRT]] : i32 to i8
+// CHECK32-NEXT: memref.dealloc %[[M]]
// CHECK32-NEXT: return %[[TRUNC]]
// -----
@@ -485,3 +488,68 @@ func.func @memref_collapse_shape_i4(%idx0 : index, %idx1 : index) -> i4 {
// CHECK32-NOT: memref.collapse_shape
// CHECK32: memref.load %[[ALLOC]][%{{.*}}] : memref<4096xi32>
+// -----
+
+func.func @memref_expand_shape_i4(%idx0 : index, %idx1 : index, %idx2 : index) -> i4 {
+ %arr = memref.alloc() : memref<256x128xi4>
+ %expand = memref.expand_shape %arr[[0, 1], [2]] output_shape [32, 8, 128] : memref<256x128xi4> into memref<32x8x128xi4>
+ %1 = memref.load %expand[%idx0, %idx1, %idx2] : memref<32x8x128xi4>
+ return %1 : i4
+}
+
+// CHECK-LABEL: func.func @memref_expand_shape_i4(
+// CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<16384xi8>
+// CHECK-NOT: memref.expand_shape
+// CHECK: memref.load %[[ALLOC]][%{{.*}}] : memref<16384xi8>
+
+// CHECK32-LABEL: func.func @memref_expand_shape_i4(
+// CHECK32: %[[ALLOC:.*]] = memref.alloc() : memref<4096xi32>
+// CHECK32-NOT: memref.expand_shape
+// CHECK32: memref.load %[[ALLOC]][%{{.*}}] : memref<4096xi32>
+
+// -----
+
+func.func @memref_memory_space_cast_i4(%arg0: memref<32x128xi4, 1>) -> memref<32x128xi4> {
+ %cast = memref.memory_space_cast %arg0 : memref<32x128xi4, 1> to memref<32x128xi4>
+ return %cast : memref<32x128xi4>
+}
+
+// CHECK-LABEL: func.func @memref_memory_space_cast_i4(
+// CHECK-SAME: %[[ARG0:.*]]: memref<2048xi8, 1>
+// CHECK: %[[CAST:.*]] = memref.memory_space_cast %[[ARG0]] : memref<2048xi8, 1> to memref<2048xi8>
+// CHECK: return %[[CAST]]
+
+// CHECK32-LABEL: func.func @memref_memory_space_cast_i4(
+// CHECK32-SAME: %[[ARG0:.*]]: memref<512xi32, 1>
+// CHECK32: %[[CAST:.*]] = memref.memory_space_cast %[[ARG0]] : memref<512xi32, 1> to memref<512xi32>
+// CHECK32: return %[[CAST]]
+
+// -----
+
+func.func @memref_copy_i4(%arg0: memref<32x128xi4, 1>, %arg1: memref<32x128xi4>) {
+ memref.copy %arg0, %arg1 : memref<32x128xi4, 1> to memref<32x128xi4>
+ return
+}
+
+// CHECK-LABEL: func.func @memref_copy_i4(
+// CHECK-SAME: %[[ARG0:.*]]: memref<2048xi8, 1>, %[[ARG1:.*]]: memref<2048xi8>
+// CHECK: memref.copy %[[ARG0]], %[[ARG1]]
+// CHECK: return
+
+// CHECK32-LABEL: func.func @memref_copy_i4(
+// CHECK32-SAME: %[[ARG0:.*]]: memref<512xi32, 1>, %[[ARG1:.*]]: memref<512xi32>
+// CHECK32: memref.copy %[[ARG0]], %[[ARG1]]
+// CHECK32: return
+
+// -----
+
+!colMajor = memref<8x8xi4, strided<[1, 8]>>
+func.func @copy_distinct_layouts(%idx : index) -> i4 {
+ %c0 = arith.constant 0 : index
+ %arr = memref.alloc() : memref<8x8xi4>
+ %arr2 = memref.alloc() : !colMajor
+ // expected-error @+1 {{failed to legalize operation 'memref.copy' that was explicitly marked illegal}}
+ memref.copy %arr, %arr2 : memref<8x8xi4> to !colMajor
+ %ld = memref.load %arr2[%c0, %c0] : !colMajor
+ return %ld : i4
+}
diff --git a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
index d884ade3195329..8aac802ba10ae9 100644
--- a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
+++ b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
@@ -1553,3 +1553,41 @@ func.func @extract_strided_metadata_of_collapse_shape(%base: memref<5x4xf32>)
// CHECK-DAG: %[[STEP:.*]] = arith.constant 1 : index
// CHECK: %[[BASE:.*]], %{{.*}}, %{{.*}}, %{{.*}} = memref.extract_strided_metadata
// CHECK: return %[[BASE]], %[[OFFSET]], %[[SIZE]], %[[STEP]] : memref<f32>, index, index, index
+
+// -----
+
+func.func @extract_strided_metadata_of_memory_space_cast(%base: memref<20xf32>)
+ -> (memref<f32, 1>, index, index, index) {
+
+ %memory_space_cast = memref.memory_space_cast %base : memref<20xf32> to memref<20xf32, 1>
+
+ %base_buffer, %offset, %size, %stride = memref.extract_strided_metadata %memory_space_cast :
+ memref<20xf32, 1> -> memref<f32, 1>, index, index, index
+
+ return %base_buffer, %offset, %size, %stride :
+ memref<f32, 1>, index, index, index
+}
+
+// CHECK-LABEL: func @extract_strided_metadata_of_memory_space_cast
+// CHECK-DAG: %[[OFFSET:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[SIZE:.*]] = arith.constant 20 : index
+// CHECK-DAG: %[[STEP:.*]] = arith.constant 1 : index
+// CHECK: %[[BASE:.*]], %{{.*}}, %{{.*}}, %{{.*}} = memref.extract_strided_metadata
+// CHECK: %[[CAST:.*]] = memref.memory_space_cast %[[BASE]]
+// CHECK: return %[[CAST]], %[[OFFSET]], %[[SIZE]], %[[STEP]] : memref<f32, 1>, index, index, index
+
+// -----
+
+func.func @extract_strided_metadata_of_memory_space_cast_no_base(%base: memref<20xf32>)
+ -> (index, index, index) {
+
+ %memory_space_cast = memref.memory_space_cast %base : memref<20xf32> to memref<20xf32, 1>
+
+ %base_buffer, %offset, %size, %stride = memref.extract_strided_metadata %memory_space_cast :
+ memref<20xf32, 1> -> memref<f32, 1>, index, index, index
+
+ return %offset, %size, %stride : index, index, index
+}
+
+// CHECK-LABEL: func @extract_strided_metadata_of_memory_space_cast_no_base
+// CHECK-NOT: memref.memory_space_cast