[Mlir-commits] [mlir] [mlir][memref] Fold extract_strided_metadata(cast(x)) into extract_strided_metadata(x) (PR #164585)

Ming Yan llvmlistbot at llvm.org
Wed Oct 22 02:23:04 PDT 2025


https://github.com/NexMing created https://github.com/llvm/llvm-project/pull/164585

None

>From 4b37b2fdc5a94bd7e6c4ea73938ced7f4ebe105c Mon Sep 17 00:00:00 2001
From: yanming <ming.yan at terapines.com>
Date: Wed, 22 Oct 2025 17:14:42 +0800
Subject: [PATCH] [mlir][memref] Fold extract_strided_metadata(cast(x)) into
 extract_strided_metadata(x)

---
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp      |  7 ++
 .../Transforms/ExpandStridedMetadata.cpp      | 87 -------------------
 mlir/test/Dialect/MemRef/canonicalize.mlir    | 15 ++++
 3 files changed, 22 insertions(+), 87 deletions(-)

diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 94947b760251e..c06a48ee4b87c 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1437,6 +1437,13 @@ ExtractStridedMetadataOp::fold(FoldAdaptor adaptor,
   atLeastOneReplacement |= replaceConstantUsesOf(
       builder, getLoc(), getStrides(), getConstifiedMixedStrides());
 
+  // extract_strided_metadata(cast(x)) -> extract_strided_metadata(x). This
+  // is done last so the replacements above use the cast's static type info.
+  if (auto castOp = getSource().getDefiningOp<CastOp>();
+      castOp && isa<MemRefType>(castOp.getSource().getType())) {
+    getSourceMutable().assign(castOp.getSource());
+    atLeastOneReplacement = true;
+  }
   return success(atLeastOneReplacement);
 }
 
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
index d35566a9c0d29..bd02516d5b527 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
@@ -1033,91 +1033,6 @@ class ExtractStridedMetadataOpReinterpretCastFolder
   }
 };
 
-/// Replace `base, offset, sizes, strides =
-///              extract_strided_metadata(
-///                 cast(src) to dstTy)`
-/// With
-/// ```
-/// base, ... = extract_strided_metadata(src)
-/// offset = !dstTy.srcOffset.isDynamic()
-///            ? dstTy.srcOffset
-///            : extract_strided_metadata(src).offset
-/// sizes = for each srcSize in dstTy.srcSizes:
-///           !srcSize.isDynamic()
-///             ? srcSize
-//              : extract_strided_metadata(src).sizes[i]
-/// strides = for each srcStride in dstTy.srcStrides:
-///             !srcStrides.isDynamic()
-///               ? srcStrides
-///               : extract_strided_metadata(src).strides[i]
-/// ```
-///
-/// In other words, consume the `cast` and apply its effects
-/// on the offset, sizes, and strides or compute them directly from `src`.
-class ExtractStridedMetadataOpCastFolder
-    : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult
-  matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
-                  PatternRewriter &rewriter) const override {
-    Value source = extractStridedMetadataOp.getSource();
-    auto castOp = source.getDefiningOp<memref::CastOp>();
-    if (!castOp)
-      return failure();
-
-    Location loc = extractStridedMetadataOp.getLoc();
-    // Check if the source is suitable for extract_strided_metadata.
-    SmallVector<Type> inferredReturnTypes;
-    if (failed(extractStridedMetadataOp.inferReturnTypes(
-            rewriter.getContext(), loc, {castOp.getSource()},
-            /*attributes=*/{}, /*properties=*/nullptr, /*regions=*/{},
-            inferredReturnTypes)))
-      return rewriter.notifyMatchFailure(castOp,
-                                         "cast source's type is incompatible");
-
-    auto memrefType = cast<MemRefType>(source.getType());
-    unsigned rank = memrefType.getRank();
-    SmallVector<OpFoldResult> results;
-    results.resize_for_overwrite(rank * 2 + 2);
-
-    auto newExtractStridedMetadata = memref::ExtractStridedMetadataOp::create(
-        rewriter, loc, castOp.getSource());
-
-    // Register the base_buffer.
-    results[0] = newExtractStridedMetadata.getBaseBuffer();
-
-    auto getConstantOrValue = [&rewriter](int64_t constant,
-                                          OpFoldResult ofr) -> OpFoldResult {
-      return ShapedType::isStatic(constant)
-                 ? OpFoldResult(rewriter.getIndexAttr(constant))
-                 : ofr;
-    };
-
-    auto [sourceStrides, sourceOffset] = memrefType.getStridesAndOffset();
-    assert(sourceStrides.size() == rank && "unexpected number of strides");
-
-    // Register the new offset.
-    results[1] =
-        getConstantOrValue(sourceOffset, newExtractStridedMetadata.getOffset());
-
-    const unsigned sizeStartIdx = 2;
-    const unsigned strideStartIdx = sizeStartIdx + rank;
-    ArrayRef<int64_t> sourceSizes = memrefType.getShape();
-
-    SmallVector<OpFoldResult> sizes = newExtractStridedMetadata.getSizes();
-    SmallVector<OpFoldResult> strides = newExtractStridedMetadata.getStrides();
-    for (unsigned i = 0; i < rank; ++i) {
-      results[sizeStartIdx + i] = getConstantOrValue(sourceSizes[i], sizes[i]);
-      results[strideStartIdx + i] =
-          getConstantOrValue(sourceStrides[i], strides[i]);
-    }
-    rewriter.replaceOp(extractStridedMetadataOp,
-                       getValueOrCreateConstantIndexOp(rewriter, loc, results));
-    return success();
-  }
-};
-
 /// Replace `base, offset, sizes, strides = extract_strided_metadata(
 ///      memory_space_cast(src) to dstTy)`
 /// with
@@ -1209,7 +1124,6 @@ void memref::populateExpandStridedMetadataPatterns(
                RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
                ExtractStridedMetadataOpReinterpretCastFolder,
                ExtractStridedMetadataOpSubviewFolder,
-               ExtractStridedMetadataOpCastFolder,
                ExtractStridedMetadataOpMemorySpaceCastFolder,
                ExtractStridedMetadataOpAssumeAlignmentFolder,
                ExtractStridedMetadataOpExtractStridedMetadataFolder>(
@@ -1226,7 +1140,6 @@ void memref::populateResolveExtractStridedMetadataPatterns(
                ExtractStridedMetadataOpSubviewFolder,
                RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
                ExtractStridedMetadataOpReinterpretCastFolder,
-               ExtractStridedMetadataOpCastFolder,
                ExtractStridedMetadataOpMemorySpaceCastFolder,
                ExtractStridedMetadataOpAssumeAlignmentFolder,
                ExtractStridedMetadataOpExtractStridedMetadataFolder>(
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 7160b52af6353..bab979bb86959 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -901,6 +901,21 @@ func.func @scope_merge_without_terminator() {
 
 // -----
 
+// CHECK-LABEL: func @extract_strided_metadata_of_cast
+//  CHECK-SAME: (%[[SRC:.*]]: memref<?xf32>)
+//       CHECK: %[[OFFSET:.*]] = arith.constant 0 : index
+//       CHECK: %[[SIZE:.*]] = arith.constant 4 : index
+//       CHECK: %[[STRIDE:.*]] = arith.constant 1 : index
+//       CHECK: %[[BUF:.*]], %{{.*}}, %{{.*}}, %{{.*}} = memref.extract_strided_metadata %[[SRC]]
+//       CHECK: return %[[BUF]], %[[OFFSET]], %[[SIZE]], %[[STRIDE]]
+func.func @extract_strided_metadata_of_cast(%src: memref<?xf32>) -> (memref<f32>, index, index, index) {
+  %0 = memref.cast %src : memref<?xf32> to memref<4xf32, strided<[?]>>
+  %buf, %off, %size, %stride = memref.extract_strided_metadata %0 : memref<4xf32, strided<[?]>> -> memref<f32>, index, index, index
+  return %buf, %off, %size, %stride : memref<f32>, index, index, index
+}
+
+// -----
+
 // CHECK-LABEL: func @reinterpret_noop
 //  CHECK-SAME: (%[[ARG:.*]]: memref<2x3x4xf32>)
 //  CHECK-NEXT: return %[[ARG]]



More information about the Mlir-commits mailing list