[Mlir-commits] [mlir] 5083e80 - Folding extract_strided_metadata input into reinterpret_cast (#134845)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Apr 9 07:50:22 PDT 2025
Author: ivangarcia44
Date: 2025-04-09T16:50:16+02:00
New Revision: 5083e80c14a5c1f0ab40b5df95771ebbdda1adb2
URL: https://github.com/llvm/llvm-project/commit/5083e80c14a5c1f0ab40b5df95771ebbdda1adb2
DIFF: https://github.com/llvm/llvm-project/commit/5083e80c14a5c1f0ab40b5df95771ebbdda1adb2.diff
LOG: Folding extract_strided_metadata input into reinterpret_cast (#134845)
We can always fold the input of an extract_strided_metadata operator to
the input of a reinterpret_cast operator, because they point to the same
memory. Note that the reinterpret_cast does not use the layout of its
input memref, only its base memory pointer which is the same as the base
pointer returned by the extract_strided_metadata operator and the base
pointer of the extract_strided_metadata memref input.
Operations like expand_shape, collapse_shape, and subview are lowered to
a pair of extract_strided_metadata and reinterpret_cast like this:
%base_buffer, %offset, %sizes:2, %strides:2 =
memref.extract_strided_metadata %input_memref :
memref<ID1x...xIDNxBaseType> -> memref<f32>, index, index, index, index,
index
%reinterpret_cast = memref.reinterpret_cast %base_buffer to offset:
[%o1], sizes: [%d1,...,%dN], strides: [%s1,...,%sN] : memref<f32> to
memref<OD1x...xODNxBaseType >
In many cases the input of the extract_strided_metadata operation can be
passed directly as the input of the reinterpret_cast operation like
this (note how %base_buffer is replaced by %input_memref in the
reinterpret_cast above and the input type is updated):
%base_buffer, %offset, %sizes:2, %strides:2 =
memref.extract_strided_metadata %input_memref :
memref<ID1x...xIDNxBaseType> -> memref<f32>, index, index, index, index,
index
%reinterpret_cast = memref.reinterpret_cast %input_memref to offset:
[%o1], sizes: [%d1,...,%dN], strides: [%s1,...,%sN] :
memref<ID1x...xIDNxBaseType> to memref<OD1x...xODNxBaseType >
When dealing with static dimensions, the extract_strided_metadata will
become dead code and we end up with only a reinterpret_cast:
%reinterpret_cast = memref.reinterpret_cast %input_memref to offset:
[%o1], sizes: [%d1,...,%dN], strides: [%s1,...,%sN] :
memref<ID1x...xIDNxBaseType> to memref<OD1x...xODNxBaseType >
Note that reinterpret_cast only reads the base memory pointer from the
input memref (%input_memref above), which is equivalent to the
%base_buffer returned by the extract_strided_metadata operation. Hence
it is always legal to use the extract_strided_metadata input memref
directly in the reinterpret_cast. Note that since this is a pointer,
this operation is legal even when the base pointer values are modified
between the operation pair.
@matthias-springer
@joker-eph
@sahas3
@Hanumanth04
@dixinzhou
@rafaelubalmw
---------
Co-authored-by: Ivan Garcia <igarcia at vdi-ah2ddp-178.dhcp.mathworks.com>
Added:
Modified:
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
mlir/test/Dialect/MemRef/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 123666848f83a..63f5251398716 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1124,7 +1124,7 @@ struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
}
} // else dim.getIndex is a block argument to reshape->getBlock and
// dominates reshape
- } // Check condition 2
+ } // Check condition 2
else if (dim->getBlock() != reshape->getBlock() &&
!dim.getIndex().getParentRegion()->isProperAncestor(
reshape->getParentRegion())) {
@@ -2034,6 +2034,11 @@ namespace {
/// ```
/// Because we know that `offset`and `c0` will hold 0
/// and `c4` will hold 4.
+///
+/// If the pattern above does not match, the input of the
+/// extract_strided_metadata is always folded into the input of the
+/// reinterpret_cast operator. This allows for dead code elimination to get rid
+/// of the extract_strided_metadata in some cases.
struct ReinterpretCastOpExtractStridedMetadataFolder
: public OpRewritePattern<ReinterpretCastOp> {
public:
@@ -2045,44 +2050,49 @@ struct ReinterpretCastOpExtractStridedMetadataFolder
op.getSource().getDefiningOp<ExtractStridedMetadataOp>();
if (!extractStridedMetadata)
return failure();
+
// Check if the reinterpret cast reconstructs a memref with the exact same
// properties as the extract strided metadata.
+ auto isReinterpretCastNoop = [&]() -> bool {
+ // First, check that the strides are the same.
+ if (!llvm::equal(extractStridedMetadata.getConstifiedMixedStrides(),
+ op.getConstifiedMixedStrides()))
+ return false;
- // First, check that the strides are the same.
- SmallVector<OpFoldResult> extractStridesOfr =
- extractStridedMetadata.getConstifiedMixedStrides();
- SmallVector<OpFoldResult> reinterpretStridesOfr =
- op.getConstifiedMixedStrides();
- if (extractStridesOfr.size() != reinterpretStridesOfr.size())
- return failure();
-
- unsigned rank = op.getType().getRank();
- for (unsigned i = 0; i < rank; ++i) {
- if (extractStridesOfr[i] != reinterpretStridesOfr[i])
- return failure();
- }
+ // Second, check the sizes.
+ if (!llvm::equal(extractStridedMetadata.getConstifiedMixedSizes(),
+ op.getConstifiedMixedSizes()))
+ return false;
- // Second, check the sizes.
- assert(extractStridedMetadata.getSizes().size() ==
- op.getMixedSizes().size() &&
- "Strides and sizes rank must match");
- SmallVector<OpFoldResult> extractSizesOfr =
- extractStridedMetadata.getConstifiedMixedSizes();
- SmallVector<OpFoldResult> reinterpretSizesOfr =
- op.getConstifiedMixedSizes();
- for (unsigned i = 0; i < rank; ++i) {
- if (extractSizesOfr[i] != reinterpretSizesOfr[i])
- return failure();
+ // Finally, check the offset.
+ assert(op.getMixedOffsets().size() == 1 &&
+ "reinterpret_cast with more than one offset should have been "
+ "rejected by the verifier");
+ return extractStridedMetadata.getConstifiedMixedOffset() ==
+ op.getConstifiedMixedOffset();
+ };
+
+ if (!isReinterpretCastNoop()) {
+ // If the extract_strided_metadata / reinterpret_cast pair can't be
+ // completely folded, then we could fold the input of the
+ // extract_strided_metadata into the input of the reinterpret_cast
+ // input. For some cases (e.g., static dimensions) the
+ // the extract_strided_metadata is eliminated by dead code elimination.
+ //
+ // reinterpret_cast(extract_strided_metadata(x)) -> reinterpret_cast(x).
+ //
+ // We can always fold the input of a extract_strided_metadata operator
+ // to the input of a reinterpret_cast operator, because they point to
+ // the same memory. Note that the reinterpret_cast does not use the
+ // layout of its input memref, only its base memory pointer which is
+ // the same as the base pointer returned by the extract_strided_metadata
+ // operator and the base pointer of the extract_strided_metadata memref
+ // input.
+ rewriter.modifyOpInPlace(op, [&]() {
+ op.getSourceMutable().assign(extractStridedMetadata.getSource());
+ });
+ return success();
}
- // Finally, check the offset.
- assert(op.getMixedOffsets().size() == 1 &&
- "reinterpret_cast with more than one offset should have been "
- "rejected by the verifier");
- OpFoldResult extractOffsetOfr =
- extractStridedMetadata.getConstifiedMixedOffset();
- OpFoldResult reinterpretOffsetOfr = op.getConstifiedMixedOffset();
- if (extractOffsetOfr != reinterpretOffsetOfr)
- return failure();
// At this point, we know that the back and forth between extract strided
// metadata and reinterpret cast is a noop. However, the final type of the
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 5d8a7d3f64e8f..e7cee7cd85426 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -952,8 +952,7 @@ func.func @reinterpret_of_extract_strided_metadata_same_type(%arg0 : memref<?x?x
// CHECK-SAME: (%[[ARG:.*]]: memref<8x2xf32>)
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]]
-// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[C0]]], sizes: [4, 2, 2], strides: [1, 1, %[[C1]]]
+// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [%[[C0]]], sizes: [4, 2, 2], strides: [1, 1, %[[C1]]]
// CHECK: return %[[RES]]
func.func @reinterpret_of_extract_strided_metadata_w_different_stride(%arg0 : memref<8x2xf32>) -> memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>> {
%base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref<f32>, index, index, index, index, index
@@ -969,8 +968,7 @@ func.func @reinterpret_of_extract_strided_metadata_w_different_stride(%arg0 : me
// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]]
-// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [1], sizes: [%[[C8]], %[[C2]]], strides: [%[[C2]], %[[C1]]]
+// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [1], sizes: [%[[C8]], %[[C2]]], strides: [%[[C2]], %[[C1]]]
// CHECK: return %[[RES]]
func.func @reinterpret_of_extract_strided_metadata_w_different_offset(%arg0 : memref<8x2xf32>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> {
%base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref<f32>, index, index, index, index, index
More information about the Mlir-commits
mailing list