[Mlir-commits] [mlir] 5083e80 - Folding extract_strided_metadata input into reinterpret_cast (#134845)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Apr 9 07:50:22 PDT 2025


Author: ivangarcia44
Date: 2025-04-09T16:50:16+02:00
New Revision: 5083e80c14a5c1f0ab40b5df95771ebbdda1adb2

URL: https://github.com/llvm/llvm-project/commit/5083e80c14a5c1f0ab40b5df95771ebbdda1adb2
DIFF: https://github.com/llvm/llvm-project/commit/5083e80c14a5c1f0ab40b5df95771ebbdda1adb2.diff

LOG: Folding extract_strided_metadata input into reinterpret_cast (#134845)

We can always fold the input of an extract_strided_metadata operator to
the input of a reinterpret_cast operator, because they point to the same
memory. Note that the reinterpret_cast does not use the layout of its
input memref, only its base memory pointer which is the same as the base
pointer returned by the extract_strided_metadata operator and the base
pointer of the extract_strided_metadata memref input.

Operations like expand_shape, collapse_shape, and subview are lowered to
a pair of extract_strided_metadata and reinterpret_cast like this:
      
%base_buffer, %offset, %sizes:2, %strides:2 =
memref.extract_strided_metadata %input_memref :
memref<ID1x...xIDNxBaseType> -> memref<f32>, index, index, index, index,
index

%reinterpret_cast = memref.reinterpret_cast %base_buffer to offset:
[%o1], sizes: [%d1,...,%dN], strides: [%s1,...,%sN] : memref<f32> to
memref<OD1x...xODNxBaseType>

In many cases the input of the extract_strided_metadata can be passed
directly as the input of the reinterpret_cast operation, like this (see
how %base_buffer is replaced by %input_memref in the reinterpret_cast
above and the input type is updated):

%base_buffer, %offset, %sizes:2, %strides:2 =
memref.extract_strided_metadata %input_memref :
memref<ID1x...xIDNxBaseType> -> memref<f32>, index, index, index, index,
index
%reinterpret_cast = memref.reinterpret_cast %input_memref to offset:
[%o1], sizes: [%d1,...,%dN], strides: [%s1,...,%sN] :
memref<ID1x...xIDNxBaseType> to memref<OD1x...xODNxBaseType>

When dealing with static dimensions, the extract_strided_metadata will
become dead code and we end up only with a reinterpret_cast:

%reinterpret_cast = memref.reinterpret_cast %input_memref to offset:
[%o1], sizes: [%d1,...,%dN], strides: [%s1,...,%sN] :
memref<ID1x...xIDNxBaseType> to memref<OD1x...xODNxBaseType>

Note that reinterpret_cast only reads the base memory pointer from the
input memref (%input_memref above), which is equivalent to the
%base_buffer returned by the extract_strided_metadata operation. Hence
it is always legal to use the extract_strided_metadata input memref
directly in the reinterpret_cast. Note that since this is a pointer,
this operation is legal even when the base pointer values are modified
between the operation pair.

@matthias-springer 
@joker-eph 
@sahas3
@Hanumanth04
@dixinzhou
@rafaelubalmw

---------

Co-authored-by: Ivan Garcia <igarcia at vdi-ah2ddp-178.dhcp.mathworks.com>

Added: 
    

Modified: 
    mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
    mlir/test/Dialect/MemRef/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 123666848f83a..63f5251398716 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1124,7 +1124,7 @@ struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
         }
       } // else dim.getIndex is a block argument to reshape->getBlock and
         // dominates reshape
-    }   // Check condition 2
+    } // Check condition 2
     else if (dim->getBlock() != reshape->getBlock() &&
              !dim.getIndex().getParentRegion()->isProperAncestor(
                  reshape->getParentRegion())) {
@@ -2034,6 +2034,11 @@ namespace {
 /// ```
 /// Because we know that `offset`and `c0` will hold 0
 /// and `c4` will hold 4.
+///
+/// If the pattern above does not match, the input of the
+/// extract_strided_metadata is always folded into the input of the
+/// reinterpret_cast operator. This allows for dead code elimination to get rid
+/// of the extract_strided_metadata in some cases.
 struct ReinterpretCastOpExtractStridedMetadataFolder
     : public OpRewritePattern<ReinterpretCastOp> {
 public:
@@ -2045,44 +2050,49 @@ struct ReinterpretCastOpExtractStridedMetadataFolder
         op.getSource().getDefiningOp<ExtractStridedMetadataOp>();
     if (!extractStridedMetadata)
       return failure();
+
     // Check if the reinterpret cast reconstructs a memref with the exact same
     // properties as the extract strided metadata.
+    auto isReinterpretCastNoop = [&]() -> bool {
+      // First, check that the strides are the same.
+      if (!llvm::equal(extractStridedMetadata.getConstifiedMixedStrides(),
+                       op.getConstifiedMixedStrides()))
+        return false;
 
-    // First, check that the strides are the same.
-    SmallVector<OpFoldResult> extractStridesOfr =
-        extractStridedMetadata.getConstifiedMixedStrides();
-    SmallVector<OpFoldResult> reinterpretStridesOfr =
-        op.getConstifiedMixedStrides();
-    if (extractStridesOfr.size() != reinterpretStridesOfr.size())
-      return failure();
-
-    unsigned rank = op.getType().getRank();
-    for (unsigned i = 0; i < rank; ++i) {
-      if (extractStridesOfr[i] != reinterpretStridesOfr[i])
-        return failure();
-    }
+      // Second, check the sizes.
+      if (!llvm::equal(extractStridedMetadata.getConstifiedMixedSizes(),
+                       op.getConstifiedMixedSizes()))
+          return false;
 
-    // Second, check the sizes.
-    assert(extractStridedMetadata.getSizes().size() ==
-               op.getMixedSizes().size() &&
-           "Strides and sizes rank must match");
-    SmallVector<OpFoldResult> extractSizesOfr =
-        extractStridedMetadata.getConstifiedMixedSizes();
-    SmallVector<OpFoldResult> reinterpretSizesOfr =
-        op.getConstifiedMixedSizes();
-    for (unsigned i = 0; i < rank; ++i) {
-      if (extractSizesOfr[i] != reinterpretSizesOfr[i])
-        return failure();
+      // Finally, check the offset.
+      assert(op.getMixedOffsets().size() == 1 &&
+             "reinterpret_cast with more than one offset should have been "
+             "rejected by the verifier");
+      return extractStridedMetadata.getConstifiedMixedOffset() ==
+             op.getConstifiedMixedOffset();
+    };
+
+    if (!isReinterpretCastNoop()) {
+      // If the extract_strided_metadata / reinterpret_cast pair can't be
+      // completely folded, then we could fold the input of the
+      // extract_strided_metadata into the input of the reinterpret_cast.
+      // For some cases (e.g., static dimensions) the
+      // extract_strided_metadata is eliminated by dead code elimination.
+      //
+      // reinterpret_cast(extract_strided_metadata(x)) -> reinterpret_cast(x).
+      //
+      // We can always fold the input of an extract_strided_metadata operator
+      // to the input of a reinterpret_cast operator, because they point to
+      // the same memory. Note that the reinterpret_cast does not use the
+      // layout of its input memref, only its base memory pointer which is
+      // the same as the base pointer returned by the extract_strided_metadata
+      // operator and the base pointer of the extract_strided_metadata memref
+      // input.
+      rewriter.modifyOpInPlace(op, [&]() {
+        op.getSourceMutable().assign(extractStridedMetadata.getSource());
+      });
+      return success();
     }
-    // Finally, check the offset.
-    assert(op.getMixedOffsets().size() == 1 &&
-           "reinterpret_cast with more than one offset should have been "
-           "rejected by the verifier");
-    OpFoldResult extractOffsetOfr =
-        extractStridedMetadata.getConstifiedMixedOffset();
-    OpFoldResult reinterpretOffsetOfr = op.getConstifiedMixedOffset();
-    if (extractOffsetOfr != reinterpretOffsetOfr)
-      return failure();
 
     // At this point, we know that the back and forth between extract strided
     // metadata and reinterpret cast is a noop. However, the final type of the

diff  --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 5d8a7d3f64e8f..e7cee7cd85426 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -952,8 +952,7 @@ func.func @reinterpret_of_extract_strided_metadata_same_type(%arg0 : memref<?x?x
 //  CHECK-SAME: (%[[ARG:.*]]: memref<8x2xf32>)
 //   CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
 //   CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-//   CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]]
-//       CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[C0]]], sizes: [4, 2, 2], strides: [1, 1, %[[C1]]]
+//       CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [%[[C0]]], sizes: [4, 2, 2], strides: [1, 1, %[[C1]]]
 //       CHECK: return %[[RES]]
 func.func @reinterpret_of_extract_strided_metadata_w_different_stride(%arg0 : memref<8x2xf32>) -> memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>> {
   %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref<f32>, index, index, index, index, index
@@ -969,8 +968,7 @@ func.func @reinterpret_of_extract_strided_metadata_w_different_stride(%arg0 : me
 //   CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
 //   CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
 //   CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-//   CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]]
-//       CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [1], sizes: [%[[C8]], %[[C2]]], strides: [%[[C2]], %[[C1]]]
+//       CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [1], sizes: [%[[C8]], %[[C2]]], strides: [%[[C2]], %[[C1]]]
 //       CHECK: return %[[RES]]
 func.func @reinterpret_of_extract_strided_metadata_w_different_offset(%arg0 : memref<8x2xf32>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> {
   %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref<f32>, index, index, index, index, index


        


More information about the Mlir-commits mailing list