[Mlir-commits] [mlir] [mlir][memref] Fold extract_strided_metadata(cast(x)) into extract_strided_metadata(x) (PR #164585)
Ming Yan
llvmlistbot at llvm.org
Wed Oct 22 02:35:41 PDT 2025
https://github.com/NexMing updated https://github.com/llvm/llvm-project/pull/164585
>From 4b37b2fdc5a94bd7e6c4ea73938ced7f4ebe105c Mon Sep 17 00:00:00 2001
From: yanming <ming.yan at terapines.com>
Date: Wed, 22 Oct 2025 17:14:42 +0800
Subject: [PATCH 1/2] [mlir][memref] Fold extract_strided_metadata(cast(x))
into extract_strided_metadata(x)
---
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 7 ++
.../Transforms/ExpandStridedMetadata.cpp | 87 -------------------
mlir/test/Dialect/MemRef/canonicalize.mlir | 15 ++++
3 files changed, 22 insertions(+), 87 deletions(-)
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 94947b760251e..c06a48ee4b87c 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1437,6 +1437,13 @@ ExtractStridedMetadataOp::fold(FoldAdaptor adaptor,
atLeastOneReplacement |= replaceConstantUsesOf(
builder, getLoc(), getStrides(), getConstifiedMixedStrides());
+ // extract_strided_metadata(cast(x)) -> extract_strided_metadata(x).
+ if (auto prev = getSource().getDefiningOp<CastOp>())
+ if (isa<MemRefType>(prev.getSource().getType())) {
+ getSourceMutable().assign(prev.getSource());
+ atLeastOneReplacement = true;
+ }
+
return success(atLeastOneReplacement);
}
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
index d35566a9c0d29..bd02516d5b527 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
@@ -1033,91 +1033,6 @@ class ExtractStridedMetadataOpReinterpretCastFolder
}
};
-/// Replace `base, offset, sizes, strides =
-/// extract_strided_metadata(
-/// cast(src) to dstTy)`
-/// With
-/// ```
-/// base, ... = extract_strided_metadata(src)
-/// offset = !dstTy.srcOffset.isDynamic()
-/// ? dstTy.srcOffset
-/// : extract_strided_metadata(src).offset
-/// sizes = for each srcSize in dstTy.srcSizes:
-/// !srcSize.isDynamic()
-/// ? srcSize
-// : extract_strided_metadata(src).sizes[i]
-/// strides = for each srcStride in dstTy.srcStrides:
-/// !srcStrides.isDynamic()
-/// ? srcStrides
-/// : extract_strided_metadata(src).strides[i]
-/// ```
-///
-/// In other words, consume the `cast` and apply its effects
-/// on the offset, sizes, and strides or compute them directly from `src`.
-class ExtractStridedMetadataOpCastFolder
- : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
- using OpRewritePattern::OpRewritePattern;
-
- LogicalResult
- matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
- PatternRewriter &rewriter) const override {
- Value source = extractStridedMetadataOp.getSource();
- auto castOp = source.getDefiningOp<memref::CastOp>();
- if (!castOp)
- return failure();
-
- Location loc = extractStridedMetadataOp.getLoc();
- // Check if the source is suitable for extract_strided_metadata.
- SmallVector<Type> inferredReturnTypes;
- if (failed(extractStridedMetadataOp.inferReturnTypes(
- rewriter.getContext(), loc, {castOp.getSource()},
- /*attributes=*/{}, /*properties=*/nullptr, /*regions=*/{},
- inferredReturnTypes)))
- return rewriter.notifyMatchFailure(castOp,
- "cast source's type is incompatible");
-
- auto memrefType = cast<MemRefType>(source.getType());
- unsigned rank = memrefType.getRank();
- SmallVector<OpFoldResult> results;
- results.resize_for_overwrite(rank * 2 + 2);
-
- auto newExtractStridedMetadata = memref::ExtractStridedMetadataOp::create(
- rewriter, loc, castOp.getSource());
-
- // Register the base_buffer.
- results[0] = newExtractStridedMetadata.getBaseBuffer();
-
- auto getConstantOrValue = [&rewriter](int64_t constant,
- OpFoldResult ofr) -> OpFoldResult {
- return ShapedType::isStatic(constant)
- ? OpFoldResult(rewriter.getIndexAttr(constant))
- : ofr;
- };
-
- auto [sourceStrides, sourceOffset] = memrefType.getStridesAndOffset();
- assert(sourceStrides.size() == rank && "unexpected number of strides");
-
- // Register the new offset.
- results[1] =
- getConstantOrValue(sourceOffset, newExtractStridedMetadata.getOffset());
-
- const unsigned sizeStartIdx = 2;
- const unsigned strideStartIdx = sizeStartIdx + rank;
- ArrayRef<int64_t> sourceSizes = memrefType.getShape();
-
- SmallVector<OpFoldResult> sizes = newExtractStridedMetadata.getSizes();
- SmallVector<OpFoldResult> strides = newExtractStridedMetadata.getStrides();
- for (unsigned i = 0; i < rank; ++i) {
- results[sizeStartIdx + i] = getConstantOrValue(sourceSizes[i], sizes[i]);
- results[strideStartIdx + i] =
- getConstantOrValue(sourceStrides[i], strides[i]);
- }
- rewriter.replaceOp(extractStridedMetadataOp,
- getValueOrCreateConstantIndexOp(rewriter, loc, results));
- return success();
- }
-};
-
/// Replace `base, offset, sizes, strides = extract_strided_metadata(
/// memory_space_cast(src) to dstTy)`
/// with
@@ -1209,7 +1124,6 @@ void memref::populateExpandStridedMetadataPatterns(
RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
ExtractStridedMetadataOpReinterpretCastFolder,
ExtractStridedMetadataOpSubviewFolder,
- ExtractStridedMetadataOpCastFolder,
ExtractStridedMetadataOpMemorySpaceCastFolder,
ExtractStridedMetadataOpAssumeAlignmentFolder,
ExtractStridedMetadataOpExtractStridedMetadataFolder>(
@@ -1226,7 +1140,6 @@ void memref::populateResolveExtractStridedMetadataPatterns(
ExtractStridedMetadataOpSubviewFolder,
RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
ExtractStridedMetadataOpReinterpretCastFolder,
- ExtractStridedMetadataOpCastFolder,
ExtractStridedMetadataOpMemorySpaceCastFolder,
ExtractStridedMetadataOpAssumeAlignmentFolder,
ExtractStridedMetadataOpExtractStridedMetadataFolder>(
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 7160b52af6353..bab979bb86959 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -901,6 +901,21 @@ func.func @scope_merge_without_terminator() {
// -----
+// CHECK-LABEL: func @extract_strided_metadata_of_cast
+// CHECK-SAME: (%[[ARG:.*]]: memref<?xf32>)
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[C4:.*]] = arith.constant 4 : index
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[BASE:.*]], %{{.*}}, %{{.*}}, %{{.*}} = memref.extract_strided_metadata %[[ARG]]
+// CHECK: return %[[BASE]], %[[C0]], %[[C4]], %[[C1]]
+func.func @extract_strided_metadata_of_cast(%arg0: memref<?xf32>) -> (memref<f32>, index, index, index) {
+ %cast = memref.cast %arg0 : memref<?xf32> to memref<4xf32, strided<[?]>>
+ %base_buffer, %offset, %sizes, %strides = memref.extract_strided_metadata %cast : memref<4xf32, strided<[?]>> -> memref<f32>, index, index, index
+ return %base_buffer, %offset, %sizes, %strides : memref<f32>, index, index, index
+}
+
+// -----
+
// CHECK-LABEL: func @reinterpret_noop
// CHECK-SAME: (%[[ARG:.*]]: memref<2x3x4xf32>)
// CHECK-NEXT: return %[[ARG]]
>From 4d413ce780ff0e5ff82f11c199314bda4083d12f Mon Sep 17 00:00:00 2001
From: yanming <ming.yan at terapines.com>
Date: Wed, 22 Oct 2025 17:35:10 +0800
Subject: [PATCH 2/2] move testcases
---
mlir/test/Dialect/MemRef/canonicalize.mlir | 131 ++++++++++++++++--
.../MemRef/expand-strided-metadata.mlir | 127 -----------------
2 files changed, 121 insertions(+), 137 deletions(-)
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index bab979bb86959..313090272ef90 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -901,17 +901,128 @@ func.func @scope_merge_without_terminator() {
// -----
+// Check that we simplify extract_strided_metadata of cast
+// when the source of the cast is compatible with what
+// `extract_strided_metadata` accepts.
+//
+// When we apply the transformation the resulting offset, sizes and strides
+// should come straight from the inputs of the cast.
+// Additionally the folder on extract_strided_metadata should propagate the
+// static information.
+//
// CHECK-LABEL: func @extract_strided_metadata_of_cast
-// CHECK-SAME: (%[[ARG:.*]]: memref<?xf32>)
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: %[[C4:.*]] = arith.constant 4 : index
-// CHECK: %[[C1:.*]] = arith.constant 1 : index
-// CHECK: %[[BASE:.*]], %{{.*}}, %{{.*}}, %{{.*}} = memref.extract_strided_metadata %[[ARG]]
-// CHECK: return %[[BASE]], %[[C0]], %[[C4]], %[[C1]]
-func.func @extract_strided_metadata_of_cast(%arg0: memref<?xf32>) -> (memref<f32>, index, index, index) {
- %cast = memref.cast %arg0 : memref<?xf32> to memref<4xf32, strided<[?]>>
- %base_buffer, %offset, %sizes, %strides = memref.extract_strided_metadata %cast : memref<4xf32, strided<[?]>> -> memref<f32>, index, index, index
- return %base_buffer, %offset, %sizes, %strides : memref<f32>, index, index, index
+// CHECK-SAME: %[[ARG:.*]]: memref<3x?xi32, strided<[4, ?], offset: ?>>)
+//
+// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
+// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+// CHECK: %[[BASE:.*]], %[[DYN_OFFSET:.*]], %[[DYN_SIZES:.*]]:2, %[[DYN_STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]]
+//
+// CHECK: return %[[BASE]], %[[DYN_OFFSET]], %[[C3]], %[[DYN_SIZES]]#1, %[[C4]], %[[DYN_STRIDES]]#1
+func.func @extract_strided_metadata_of_cast(
+ %arg : memref<3x?xi32, strided<[4, ?], offset:?>>)
+ -> (memref<i32>, index,
+ index, index,
+ index, index) {
+
+ %cast =
+ memref.cast %arg :
+ memref<3x?xi32, strided<[4, ?], offset: ?>> to
+ memref<?x?xi32, strided<[?, ?], offset: ?>>
+
+ %base, %base_offset, %sizes:2, %strides:2 =
+ memref.extract_strided_metadata %cast:memref<?x?xi32, strided<[?, ?], offset: ?>>
+ -> memref<i32>, index,
+ index, index,
+ index, index
+
+ return %base, %base_offset,
+ %sizes#0, %sizes#1,
+ %strides#0, %strides#1 :
+ memref<i32>, index,
+ index, index,
+ index, index
+}
+
+// -----
+
+// Check that we simplify extract_strided_metadata of cast
+// when the source of the cast is compatible with what
+// `extract_strided_metadata` accepts.
+//
+// Same as extract_strided_metadata_of_cast but with a constant offset, sizes,
+// and strides in the destination type.
+//
+// CHECK-LABEL: func @extract_strided_metadata_of_cast_w_csts
+// CHECK-SAME: %[[ARG:.*]]: memref<?x?xi32, strided<[?, ?], offset: ?>>)
+//
+// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+// CHECK-DAG: %[[C18:.*]] = arith.constant 18 : index
+// CHECK-DAG: %[[C25:.*]] = arith.constant 25 : index
+// CHECK: %[[BASE:.*]], %[[DYN_OFFSET:.*]], %[[DYN_SIZES:.*]]:2, %[[DYN_STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]]
+//
+// CHECK: return %[[BASE]], %[[C25]], %[[C4]], %[[DYN_SIZES]]#1, %[[DYN_STRIDES]]#0, %[[C18]]
+func.func @extract_strided_metadata_of_cast_w_csts(
+ %arg : memref<?x?xi32, strided<[?, ?], offset:?>>)
+ -> (memref<i32>, index,
+ index, index,
+ index, index) {
+
+ %cast =
+ memref.cast %arg :
+ memref<?x?xi32, strided<[?, ?], offset: ?>> to
+ memref<4x?xi32, strided<[?, 18], offset: 25>>
+
+ %base, %base_offset, %sizes:2, %strides:2 =
+ memref.extract_strided_metadata %cast:memref<4x?xi32, strided<[?, 18], offset: 25>>
+ -> memref<i32>, index,
+ index, index,
+ index, index
+
+ return %base, %base_offset,
+ %sizes#0, %sizes#1,
+ %strides#0, %strides#1 :
+ memref<i32>, index,
+ index, index,
+ index, index
+}
+
+// -----
+
+// Check that we don't simplify extract_strided_metadata of
+// cast when the source of the cast is unranked.
+// Unranked memrefs cannot feed into extract_strided_metadata operations.
+// Note: Technically we could still fold the sizes and strides.
+//
+// CHECK-LABEL: func @extract_strided_metadata_of_cast_unranked
+// CHECK-SAME: %[[ARG:.*]]: memref<*xi32>)
+//
+// CHECK: %[[CAST:.*]] = memref.cast %[[ARG]] :
+// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[CAST]]
+//
+// CHECK: return %[[BASE]], %[[OFFSET]], %[[SIZES]]#0, %[[SIZES]]#1, %[[STRIDES]]#0, %[[STRIDES]]#1
+func.func @extract_strided_metadata_of_cast_unranked(
+ %arg : memref<*xi32>)
+ -> (memref<i32>, index,
+ index, index,
+ index, index) {
+
+ %cast =
+ memref.cast %arg :
+ memref<*xi32> to
+ memref<?x?xi32, strided<[?, ?], offset: ?>>
+
+ %base, %base_offset, %sizes:2, %strides:2 =
+ memref.extract_strided_metadata %cast:memref<?x?xi32, strided<[?, ?], offset: ?>>
+ -> memref<i32>, index,
+ index, index,
+ index, index
+
+ return %base, %base_offset,
+ %sizes#0, %sizes#1,
+ %strides#0, %strides#1 :
+ memref<i32>, index,
+ index, index,
+ index, index
}
// -----
diff --git a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
index 1e6b0111fa4c7..18cdfb73f6ba8 100644
--- a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
+++ b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
@@ -1376,133 +1376,6 @@ func.func @extract_strided_metadata_of_get_global_with_offset()
memref<i32>, index, index, index, index, index
}
-// -----
-
-// Check that we simplify extract_strided_metadata of cast
-// when the source of the cast is compatible with what
-// `extract_strided_metadata`s accept.
-//
-// When we apply the transformation the resulting offset, sizes and strides
-// should come straight from the inputs of the cast.
-// Additionally the folder on extract_strided_metadata should propagate the
-// static information.
-//
-// CHECK-LABEL: func @extract_strided_metadata_of_cast
-// CHECK-SAME: %[[ARG:.*]]: memref<3x?xi32, strided<[4, ?], offset: ?>>)
-//
-// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
-// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
-// CHECK: %[[BASE:.*]], %[[DYN_OFFSET:.*]], %[[DYN_SIZES:.*]]:2, %[[DYN_STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]]
-//
-// CHECK: return %[[BASE]], %[[DYN_OFFSET]], %[[C3]], %[[DYN_SIZES]]#1, %[[C4]], %[[DYN_STRIDES]]#1
-func.func @extract_strided_metadata_of_cast(
- %arg : memref<3x?xi32, strided<[4, ?], offset:?>>)
- -> (memref<i32>, index,
- index, index,
- index, index) {
-
- %cast =
- memref.cast %arg :
- memref<3x?xi32, strided<[4, ?], offset: ?>> to
- memref<?x?xi32, strided<[?, ?], offset: ?>>
-
- %base, %base_offset, %sizes:2, %strides:2 =
- memref.extract_strided_metadata %cast:memref<?x?xi32, strided<[?, ?], offset: ?>>
- -> memref<i32>, index,
- index, index,
- index, index
-
- return %base, %base_offset,
- %sizes#0, %sizes#1,
- %strides#0, %strides#1 :
- memref<i32>, index,
- index, index,
- index, index
-}
-
-// -----
-
-// Check that we simplify extract_strided_metadata of cast
-// when the source of the cast is compatible with what
-// `extract_strided_metadata`s accept.
-//
-// Same as extract_strided_metadata_of_cast but with constant sizes and strides
-// in the destination type.
-//
-// CHECK-LABEL: func @extract_strided_metadata_of_cast_w_csts
-// CHECK-SAME: %[[ARG:.*]]: memref<?x?xi32, strided<[?, ?], offset: ?>>)
-//
-// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
-// CHECK-DAG: %[[C18:.*]] = arith.constant 18 : index
-// CHECK-DAG: %[[C25:.*]] = arith.constant 25 : index
-// CHECK: %[[BASE:.*]], %[[DYN_OFFSET:.*]], %[[DYN_SIZES:.*]]:2, %[[DYN_STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]]
-//
-// CHECK: return %[[BASE]], %[[C25]], %[[C4]], %[[DYN_SIZES]]#1, %[[DYN_STRIDES]]#0, %[[C18]]
-func.func @extract_strided_metadata_of_cast_w_csts(
- %arg : memref<?x?xi32, strided<[?, ?], offset:?>>)
- -> (memref<i32>, index,
- index, index,
- index, index) {
-
- %cast =
- memref.cast %arg :
- memref<?x?xi32, strided<[?, ?], offset: ?>> to
- memref<4x?xi32, strided<[?, 18], offset: 25>>
-
- %base, %base_offset, %sizes:2, %strides:2 =
- memref.extract_strided_metadata %cast:memref<4x?xi32, strided<[?, 18], offset: 25>>
- -> memref<i32>, index,
- index, index,
- index, index
-
- return %base, %base_offset,
- %sizes#0, %sizes#1,
- %strides#0, %strides#1 :
- memref<i32>, index,
- index, index,
- index, index
-}
-
-// -----
-
-// Check that we don't simplify extract_strided_metadata of
-// cast when the source of the cast is unranked.
-// Unranked memrefs cannot feed into extract_strided_metadata operations.
-// Note: Technically we could still fold the sizes and strides.
-//
-// CHECK-LABEL: func @extract_strided_metadata_of_cast_unranked
-// CHECK-SAME: %[[ARG:.*]]: memref<*xi32>)
-//
-// CHECK: %[[CAST:.*]] = memref.cast %[[ARG]] :
-// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[CAST]]
-//
-// CHECK: return %[[BASE]], %[[OFFSET]], %[[SIZES]]#0, %[[SIZES]]#1, %[[STRIDES]]#0, %[[STRIDES]]#1
-func.func @extract_strided_metadata_of_cast_unranked(
- %arg : memref<*xi32>)
- -> (memref<i32>, index,
- index, index,
- index, index) {
-
- %cast =
- memref.cast %arg :
- memref<*xi32> to
- memref<?x?xi32, strided<[?, ?], offset: ?>>
-
- %base, %base_offset, %sizes:2, %strides:2 =
- memref.extract_strided_metadata %cast:memref<?x?xi32, strided<[?, ?], offset: ?>>
- -> memref<i32>, index,
- index, index,
- index, index
-
- return %base, %base_offset,
- %sizes#0, %sizes#1,
- %strides#0, %strides#1 :
- memref<i32>, index,
- index, index,
- index, index
-}
-
-
// -----
memref.global "private" @dynamicShmem : memref<0xf16,3>
More information about the Mlir-commits
mailing list