[Mlir-commits] [mlir] Folding extract_strided_metadata input into reinterpret_cast on constant layout (PR #134845)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Apr 9 05:15:00 PDT 2025


https://github.com/ivangarcia44 updated https://github.com/llvm/llvm-project/pull/134845

>From dd31b5f193e7e1f1df8a0c6da3a5c304f85da990 Mon Sep 17 00:00:00 2001
From: Ivan Garcia <igarcia at vdi-ah2ddp-178.dhcp.mathworks.com>
Date: Tue, 8 Apr 2025 07:37:41 -0400
Subject: [PATCH 1/3] Folding extract_strided_metadata input into
 reinterpret_cast on constant layout.

---
 .../mlir/Dialect/MemRef/IR/MemRefOps.td       |  3 ++
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp      | 39 ++++++++++++++++++-
 .../expand-then-convert-to-llvm.mlir          | 22 +++++------
 .../MemRef/expand-strided-metadata.mlir       |  6 +--
 4 files changed, 53 insertions(+), 17 deletions(-)

diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 3edc2433c85ea..990a282771502 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1440,6 +1440,9 @@ def MemRef_ReinterpretCastOp
     SmallVector<OpFoldResult> getConstifiedMixedStrides();
     /// Similar to `getConstifiedMixedSizes` but for the offset.
     OpFoldResult getConstifiedMixedOffset();
+    /// Returns true if the reinterpret cast operation's offset, stride, and 
+    /// size are all constant.
+    bool isLayoutConstant();
   }];
 
   let hasFolder = 1;
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 123666848f83a..629d0d8d425b1 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1124,7 +1124,7 @@ struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
         }
       } // else dim.getIndex is a block argument to reshape->getBlock and
         // dominates reshape
-    }   // Check condition 2
+    } // Check condition 2
     else if (dim->getBlock() != reshape->getBlock() &&
              !dim.getIndex().getParentRegion()->isProperAncestor(
                  reshape->getParentRegion())) {
@@ -1948,6 +1948,27 @@ OpFoldResult ReinterpretCastOp::fold(FoldAdaptor /*operands*/) {
     if (auto prev = src.getDefiningOp<CastOp>())
       return prev.getSource();
 
+    // reinterpret_cast(extract_strided_metadata(x)) -> reinterpret_cast(x).
+    //
+    // We can always fold the input of an extract_strided_metadata operator
+    // to the input of a reinterpret_cast operator, because they point to
+    // the same memory. Note that the reinterpret_cast does not use the
+    // layout of its input memref, only its base memory pointer which is
+    // the same as the base pointer returned by the extract_strided_metadata
+    // operator and the base pointer of the extract_strided_metadata memref
+    // input. This folding is only profitable when the reinterpret_cast
+    // layout is constant, because the extract_strided_metadata gets
+    // eliminated by dead code elimination. For non-constant folding we don't
+    // get the extract_strided_metadata node eliminated, and one of the LLVM
+    // tests regresses in performance because the folding gets in the way of
+    // another optimization. For this reason the folding is only done on
+    // constant layout.
+    if (auto prev = src.getDefiningOp<ExtractStridedMetadataOp>()) {
+      if (isLayoutConstant()) {
+        return prev.getSource();
+      }
+    }
+
     // reinterpret_cast(subview(x)) -> reinterpret_cast(x) if subview offsets
     // are 0.
     if (auto prev = src.getDefiningOp<SubViewOp>())
@@ -1973,6 +1994,22 @@ OpFoldResult ReinterpretCastOp::fold(FoldAdaptor /*operands*/) {
   return nullptr;
 }
 
+bool ReinterpretCastOp::isLayoutConstant() {
+  if (llvm::all_of(
+          getOffsets(),
+          [](OpFoldResult val) { return isConstantIntValue(val, 0); }) &&
+      llvm::all_of(
+          getStrides(),
+          [](OpFoldResult val) { return isConstantIntValue(val, 0); }) &&
+      llvm::all_of(getSizes(), [](OpFoldResult val) {
+        return isConstantIntValue(val, 0);
+      })) {
+    return true;
+  } else {
+    return false;
+  }
+}
+
 SmallVector<OpFoldResult> ReinterpretCastOp::getConstifiedMixedSizes() {
   SmallVector<OpFoldResult> values = getMixedSizes();
   constifyIndexValues(values, getType(), getContext(), getConstantSizes,
diff --git a/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
index fe91d26d5a251..f1cb9c9f165be 100644
--- a/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
@@ -195,10 +195,10 @@ func.func @subview_const_stride(%0 : memref<64x4xf32, strided<[4, 1], offset: 0>
 func.func @subview_const_stride_and_offset(%0 : memref<64x8xf32, strided<[8, 1], offset: 0>>) -> memref<62x3xf32, strided<[8, 1], offset: 2>> {
   // The last "insertvalue" that populates the memref descriptor from the function arguments.
   // CHECK: %[[MEMREF:.*]] = builtin.unrealized_conversion_cast %[[MEM]]
+  // CHECK: %[[DESC:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 
   // CHECK: %[[BASE:.*]] = llvm.extractvalue %[[MEMREF]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
   // CHECK: %[[BASE_ALIGNED:.*]] = llvm.extractvalue %[[MEMREF]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
-  // CHECK: %[[DESC:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
   // CHECK: %[[DESC0:.*]] = llvm.insertvalue %[[BASE]], %[[DESC]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
   // CHECK: %[[DESC1:.*]] = llvm.insertvalue %[[BASE_ALIGNED]], %[[DESC0]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
   // CHECK: %[[CST_OFF:.*]] = llvm.mlir.constant(2 : index) : i64
@@ -265,11 +265,11 @@ func.func @subview_mixed_static_dynamic(%0 : memref<64x4xf32, strided<[4, 1], of
 // CHECK:         %[[MEM:.*]]: memref<{{.*}}>,
 func.func @subview_leading_operands(%0 : memref<5x3xf32>, %1: memref<5x?xf32>) -> memref<3x3xf32, strided<[3, 1], offset: 6>> {
   // CHECK: %[[MEMREF:.*]] = builtin.unrealized_conversion_cast %[[MEM]]
+  // CHECK: %[[DESC:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
   // Alloc ptr
   // CHECK: %[[BASE:.*]] = llvm.extractvalue %[[MEMREF]][0] : !llvm.struct<(ptr, ptr, i64
   // Aligned ptr
   // CHECK: %[[BASE_ALIGNED:.*]] = llvm.extractvalue %[[MEMREF]][1] : !llvm.struct<(ptr, ptr, i64
-  // CHECK: %[[DESC:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
   // CHECK: %[[DESC0:.*]] = llvm.insertvalue %[[BASE]], %[[DESC]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
   // CHECK: %[[DESC1:.*]] = llvm.insertvalue %[[BASE_ALIGNED]], %[[DESC0]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
   // Offset
@@ -331,9 +331,9 @@ func.func @subview_leading_operands_dynamic(%0 : memref<5x?xf32>) -> memref<3x?x
 // CHECK:         %[[MEM:.*]]: memref
 func.func @subview_rank_reducing_leading_operands(%0 : memref<5x3xf32>) -> memref<3xf32, strided<[1], offset: 3>> {
   // CHECK: %[[MEMREF:.*]] = builtin.unrealized_conversion_cast %[[MEM]]
+  // CHECK: %[[DESC:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
   // CHECK: %[[BASE:.*]] = llvm.extractvalue %[[MEMREF]][0] : !llvm.struct<(ptr, ptr, i64
   // CHECK: %[[BASE_ALIGNED:.*]] = llvm.extractvalue %[[MEMREF]][1] : !llvm.struct<(ptr, ptr, i64
-  // CHECK: %[[DESC:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
   // Alloc ptr
   // CHECK: %[[DESC0:.*]] = llvm.insertvalue %[[BASE]], %[[DESC]][0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
   // Aligned ptr
@@ -356,9 +356,9 @@ func.func @subview_rank_reducing_leading_operands(%0 : memref<5x3xf32>) -> memre
 // CHECK-SAME: (%[[MEM:.*]]: memref<7xf32>)
 func.func @subview_negative_stride(%arg0 : memref<7xf32>) -> memref<7xf32, strided<[-1], offset: 6>> {
   // CHECK: %[[MEMREF:.*]] = builtin.unrealized_conversion_cast %[[MEM]]
+  // CHECK: %[[DESC:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
   // CHECK: %[[BASE:.*]] = llvm.extractvalue %[[MEMREF]][0] : !llvm.struct<(ptr, ptr, i64
   // CHECK: %[[BASE_ALIGNED:.*]] = llvm.extractvalue %[[MEMREF]][1] : !llvm.struct<(ptr, ptr, i64
-  // CHECK: %[[DESC:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
   // CHECK: %[[DESC0:.*]] = llvm.insertvalue %[[BASE]], %[[DESC]][0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
   // CHECK: %[[DESC1:.*]] = llvm.insertvalue %[[BASE_ALIGNED]], %[[DESC0]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
   // CHECK: %[[CST_OFF0:.*]] = llvm.mlir.constant(6 : index) : i64
@@ -384,12 +384,12 @@ func.func @collapse_shape_static(%arg0: memref<1x3x4x1x5xf32>) -> memref<3x4x5xf
 // CHECK-LABEL: func @collapse_shape_static
 // CHECK-SAME: %[[ARG:.*]]: memref<1x3x4x1x5xf32>) -> memref<3x4x5xf32> {
 // CHECK:           %[[MEM:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : memref<1x3x4x1x5xf32> to !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)>
+// CHECK:           %[[DESC:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
 // CHECK:           %[[BASE_BUFFER:.*]] = llvm.extractvalue %[[MEM]][0] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)>
 // CHECK:           %[[ALIGNED_BUFFER:.*]] = llvm.extractvalue %[[MEM]][1] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)>
-// CHECK:           %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64
-// CHECK:           %[[DESC:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
 // CHECK:           %[[DESC0:.*]] = llvm.insertvalue %[[BASE_BUFFER]], %[[DESC]][0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
 // CHECK:           %[[DESC1:.*]] = llvm.insertvalue %[[ALIGNED_BUFFER]], %[[DESC0]][1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64
 // CHECK:           %[[DESC2:.*]] = llvm.insertvalue %[[C0]], %[[DESC1]][2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
 // CHECK:           %[[C3:.*]] = llvm.mlir.constant(3 : index) : i64
 // CHECK:           %[[DESC3:.*]] = llvm.insertvalue %[[C3]], %[[DESC2]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
@@ -458,12 +458,12 @@ func.func @expand_shape_static(%arg0: memref<3x4x5xf32>) -> memref<1x3x4x1x5xf32
 // CHECK-LABEL: func @expand_shape_static
 // CHECK-SAME: %[[ARG:.*]]: memref<3x4x5xf32>) -> memref<1x3x4x1x5xf32> {
 // CHECK:           %[[MEM:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : memref<3x4x5xf32> to !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[DESC:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)>
 // CHECK:           %[[BASE_BUFFER:.*]] = llvm.extractvalue %[[MEM]][0] : !llvm.struct<(ptr, ptr, i64,
 // CHECK:           %[[ALIGNED_BUFFER:.*]] = llvm.extractvalue %[[MEM]][1] : !llvm.struct<(ptr, ptr, i64,
-// CHECK:           %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64
-// CHECK:           %[[DESC:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)>
 // CHECK:           %[[DESC0:.*]] = llvm.insertvalue %[[BASE_BUFFER]], %[[DESC]][0] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)>
 // CHECK:           %[[DESC1:.*]] = llvm.insertvalue %[[ALIGNED_BUFFER]], %[[DESC0]][1] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)>
+// CHECK:           %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64
 // CHECK:           %[[DESC2:.*]] = llvm.insertvalue %[[C0]], %[[DESC1]][2] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)>
 // CHECK:           %[[C1:.*]] = llvm.mlir.constant(1 : index) : i64
 // CHECK:           %[[DESC3:.*]] = llvm.insertvalue %[[C1]], %[[DESC2]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)>
@@ -494,9 +494,9 @@ func.func @collapse_shape_fold_zero_dim(%arg0 : memref<1x1xf32>) -> memref<f32>
 // CHECK-LABEL:   func.func @collapse_shape_fold_zero_dim(
 // CHECK-SAME:                                            %[[ARG:.*]]: memref<1x1xf32>) -> memref<f32> {
 // CHECK:           %[[MEM:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : memref<1x1xf32> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK:           %[[DESC:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64)>
 // CHECK:           %[[BASE_BUFFER:.*]] = llvm.extractvalue %[[MEM]][0] : !llvm.struct<(ptr, ptr, i64,
 // CHECK:           %[[ALIGNED_BUFFER:.*]] = llvm.extractvalue %[[MEM]][1] : !llvm.struct<(ptr, ptr, i64,
-// CHECK:           %[[DESC:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64)>
 // CHECK:           %[[DESC0:.*]] = llvm.insertvalue %[[BASE_BUFFER]], %[[DESC]][0] : !llvm.struct<(ptr, ptr, i64)>
 // CHECK:           %[[DESC1:.*]] = llvm.insertvalue %[[ALIGNED_BUFFER]], %[[DESC0]][1] : !llvm.struct<(ptr, ptr, i64)>
 // CHECK:           %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64
@@ -515,12 +515,12 @@ func.func @expand_shape_zero_dim(%arg0 : memref<f32>) -> memref<1x1xf32> {
 // CHECK-LABEL:   func.func @expand_shape_zero_dim(
 // CHECK-SAME:                                     %[[ARG:.*]]: memref<f32>) -> memref<1x1xf32> {
 // CHECK:           %[[MEM:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : memref<f32> to !llvm.struct<(ptr, ptr, i64)>
+// CHECK:           %[[DESC:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK:           %[[BASE_BUFFER:.*]] = llvm.extractvalue %[[MEM]][0] : !llvm.struct<(ptr, ptr, i64)>
 // CHECK:           %[[ALIGNED_BUFFER:.*]] = llvm.extractvalue %[[MEM]][1] : !llvm.struct<(ptr, ptr, i64)>
-// CHECK:           %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64
-// CHECK:           %[[DESC:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK:           %[[DESC0:.*]] = llvm.insertvalue %[[BASE_BUFFER]], %[[DESC]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK:           %[[DESC1:.*]] = llvm.insertvalue %[[ALIGNED_BUFFER]], %[[DESC0]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK:           %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64
 // CHECK:           %[[DESC2:.*]] = llvm.insertvalue %[[C0]], %[[DESC1]][2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK:           %[[C1:.*]] = llvm.mlir.constant(1 : index) : i64
 // CHECK:           %[[DESC3:.*]] = llvm.insertvalue %[[C1]], %[[DESC2]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
diff --git a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
index 1e6b0111fa4c7..da74c73ccd7a5 100644
--- a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
+++ b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
@@ -975,11 +975,7 @@ func.func @simplify_collapse(%arg : memref<?x?x4x?x6x7xi32>)
 //
 // CHECK-LABEL: func @simplify_collapse_with_dim_of_size1(
 //  CHECK-SAME: %[[ARG:.*]]: memref<3x1xf32, strided<[2, 1]>>,
-//
-//       CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]] : memref<3x1xf32, strided<[2, 1]>>
-//
-//
-//       CHECK: %[[COLLAPSE_VIEW:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [0], sizes: [3], strides: [2]
+//       CHECK: %[[COLLAPSE_VIEW:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [0], sizes: [3], strides: [2]
 func.func @simplify_collapse_with_dim_of_size1(%arg0: memref<3x1xf32, strided<[2,1]>>, %arg1: memref<3xf32>) {
 
   %collapse_shape = memref.collapse_shape %arg0 [[0, 1]] :
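
To illustrate the new fold, here is a minimal before/after sketch (hypothetical IR, not part of the patch's test suite; the name @fold_example is made up). With a fully constant layout, the fold rewires the reinterpret_cast to read from the original memref, and the now-unused extract_strided_metadata is removed by dead code elimination:

func.func @fold_example(%arg0 : memref<8x2xf32>) -> memref<4x2xf32, strided<[4, 1]>> {
  %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0
      : memref<8x2xf32> -> memref<f32>, index, index, index, index, index
  // Offset, sizes, and strides are all static, so the fold applies.
  %cast = memref.reinterpret_cast %base to offset: [0], sizes: [4, 2], strides: [4, 1]
      : memref<f32> to memref<4x2xf32, strided<[4, 1]>>
  return %cast : memref<4x2xf32, strided<[4, 1]>>
}

After folding, %cast reads directly from %arg0:

  %cast = memref.reinterpret_cast %arg0 to offset: [0], sizes: [4, 2], strides: [4, 1]
      : memref<8x2xf32> to memref<4x2xf32, strided<[4, 1]>>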

>From 94f4b167a99cd1435b5abc88654edacc2dd7edbf Mon Sep 17 00:00:00 2001
From: Ivan Garcia <igarcia at vdi-ah2ddp-178.dhcp.mathworks.com>
Date: Tue, 8 Apr 2025 16:37:36 -0400
Subject: [PATCH 2/3] Fixing issue found by Matthias Springer and
 improving/simplifying change.

---
 .../mlir/Dialect/MemRef/IR/MemRefOps.td       |   3 -
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp      | 121 ++++++++----------
 .../expand-then-convert-to-llvm.mlir          |  22 ++--
 mlir/test/Dialect/MemRef/canonicalize.mlir    |   6 +-
 .../MemRef/expand-strided-metadata.mlir       |   6 +-
 5 files changed, 73 insertions(+), 85 deletions(-)

diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 990a282771502..3edc2433c85ea 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1440,9 +1440,6 @@ def MemRef_ReinterpretCastOp
     SmallVector<OpFoldResult> getConstifiedMixedStrides();
     /// Similar to `getConstifiedMixedSizes` but for the offset.
     OpFoldResult getConstifiedMixedOffset();
-    /// Returns true if the reinterpret cast operation's offset, stride, and 
-    /// size are all constant.
-    bool isLayoutConstant();
   }];
 
   let hasFolder = 1;
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 629d0d8d425b1..895baecd8c42b 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1948,27 +1948,6 @@ OpFoldResult ReinterpretCastOp::fold(FoldAdaptor /*operands*/) {
     if (auto prev = src.getDefiningOp<CastOp>())
       return prev.getSource();
 
-    // reinterpret_cast(extract_strided_metadata(x)) -> reinterpret_cast(x).
-    //
-    // We can always fold the input of an extract_strided_metadata operator
-    // to the input of a reinterpret_cast operator, because they point to
-    // the same memory. Note that the reinterpret_cast does not use the
-    // layout of its input memref, only its base memory pointer which is
-    // the same as the base pointer returned by the extract_strided_metadata
-    // operator and the base pointer of the extract_strided_metadata memref
-    // input. This folding is only profitable when the reinterpret_cast
-    // layout is constant, because the extract_strided_metadata gets
-    // eliminated by dead code elimination. For non-constant folding we don't
-    // get the extract_strided_metadata node eliminated, and one of the LLVM
-    // tests regresses in performance because the folding gets in the way of
-    // another optimization. For this reason the folding is only done on
-    // constant layout.
-    if (auto prev = src.getDefiningOp<ExtractStridedMetadataOp>()) {
-      if (isLayoutConstant()) {
-        return prev.getSource();
-      }
-    }
-
     // reinterpret_cast(subview(x)) -> reinterpret_cast(x) if subview offsets
     // are 0.
     if (auto prev = src.getDefiningOp<SubViewOp>())
@@ -1994,22 +1973,6 @@ OpFoldResult ReinterpretCastOp::fold(FoldAdaptor /*operands*/) {
   return nullptr;
 }
 
-bool ReinterpretCastOp::isLayoutConstant() {
-  if (llvm::all_of(
-          getOffsets(),
-          [](OpFoldResult val) { return isConstantIntValue(val, 0); }) &&
-      llvm::all_of(
-          getStrides(),
-          [](OpFoldResult val) { return isConstantIntValue(val, 0); }) &&
-      llvm::all_of(getSizes(), [](OpFoldResult val) {
-        return isConstantIntValue(val, 0);
-      })) {
-    return true;
-  } else {
-    return false;
-  }
-}
-
 SmallVector<OpFoldResult> ReinterpretCastOp::getConstifiedMixedSizes() {
   SmallVector<OpFoldResult> values = getMixedSizes();
   constifyIndexValues(values, getType(), getContext(), getConstantSizes,
@@ -2071,6 +2034,11 @@ namespace {
 /// ```
 /// Because we know that `offset` and `c0` will hold 0
 /// and `c4` will hold 4.
+///
+/// If the pattern above does not match, the input of the extract_strided_metadata
+/// is always folded into the input of the reinterpret_cast operator. This allows
+/// for dead code elimination to get rid of the extract_strided_metadata in some
+/// cases.
 struct ReinterpretCastOpExtractStridedMetadataFolder
     : public OpRewritePattern<ReinterpretCastOp> {
 public:
@@ -2082,44 +2050,65 @@ struct ReinterpretCastOpExtractStridedMetadataFolder
         op.getSource().getDefiningOp<ExtractStridedMetadataOp>();
     if (!extractStridedMetadata)
       return failure();
+
     // Check if the reinterpret cast reconstructs a memref with the exact same
     // properties as the extract strided metadata.
-
-    // First, check that the strides are the same.
     SmallVector<OpFoldResult> extractStridesOfr =
         extractStridedMetadata.getConstifiedMixedStrides();
     SmallVector<OpFoldResult> reinterpretStridesOfr =
         op.getConstifiedMixedStrides();
-    if (extractStridesOfr.size() != reinterpretStridesOfr.size())
-      return failure();
+    auto isReinterpretCastNoop = [&]() -> bool {
+      // First, check that the strides are the same.
+      if (extractStridesOfr.size() != reinterpretStridesOfr.size())
+        return false;
 
-    unsigned rank = op.getType().getRank();
-    for (unsigned i = 0; i < rank; ++i) {
-      if (extractStridesOfr[i] != reinterpretStridesOfr[i])
-        return failure();
-    }
+      unsigned rank = op.getType().getRank();
+      for (unsigned i = 0; i < rank; ++i) {
+        if (extractStridesOfr[i] != reinterpretStridesOfr[i])
+          return false;
+      }
 
-    // Second, check the sizes.
-    assert(extractStridedMetadata.getSizes().size() ==
-               op.getMixedSizes().size() &&
-           "Strides and sizes rank must match");
-    SmallVector<OpFoldResult> extractSizesOfr =
-        extractStridedMetadata.getConstifiedMixedSizes();
-    SmallVector<OpFoldResult> reinterpretSizesOfr =
-        op.getConstifiedMixedSizes();
-    for (unsigned i = 0; i < rank; ++i) {
-      if (extractSizesOfr[i] != reinterpretSizesOfr[i])
-        return failure();
+      // Second, check the sizes.
+      assert(extractStridedMetadata.getSizes().size() ==
+                op.getMixedSizes().size() &&
+            "Strides and sizes rank must match");
+      SmallVector<OpFoldResult> extractSizesOfr =
+          extractStridedMetadata.getConstifiedMixedSizes();
+      SmallVector<OpFoldResult> reinterpretSizesOfr =
+          op.getConstifiedMixedSizes();
+      for (unsigned i = 0; i < rank; ++i) {
+        if (extractSizesOfr[i] != reinterpretSizesOfr[i])
+          return false;
+      }
+      // Finally, check the offset.
+      assert(op.getMixedOffsets().size() == 1 &&
+            "reinterpret_cast with more than one offset should have been "
+            "rejected by the verifier");
+      OpFoldResult extractOffsetOfr =
+          extractStridedMetadata.getConstifiedMixedOffset();
+      OpFoldResult reinterpretOffsetOfr = op.getConstifiedMixedOffset();
+      return extractOffsetOfr == reinterpretOffsetOfr;
+    };
+
+    if (!isReinterpretCastNoop()) {
+      // If the extract_strided_metadata / reinterpret_cast pair can't be
+      // completely folded, then we could fold the input of the
+      // extract_strided_metadata into the input of the reinterpret_cast
+      // input. For some cases (e.g., static dimensions) the 
+      // the extract_strided_metadata is eliminated by dead code elimination.
+      //
+      // reinterpret_cast(extract_strided_metadata(x)) -> reinterpret_cast(x).
+      //
+      // We can always fold the input of an extract_strided_metadata operator
+      // to the input of a reinterpret_cast operator, because they point to
+      // the same memory. Note that the reinterpret_cast does not use the
+      // layout of its input memref, only its base memory pointer which is
+      // the same as the base pointer returned by the extract_strided_metadata
+      // operator and the base pointer of the extract_strided_metadata memref
+      // input.
+      op.setOperand(0, extractStridedMetadata.getSource());
+      return success();
     }
-    // Finally, check the offset.
-    assert(op.getMixedOffsets().size() == 1 &&
-           "reinterpret_cast with more than one offset should have been "
-           "rejected by the verifier");
-    OpFoldResult extractOffsetOfr =
-        extractStridedMetadata.getConstifiedMixedOffset();
-    OpFoldResult reinterpretOffsetOfr = op.getConstifiedMixedOffset();
-    if (extractOffsetOfr != reinterpretOffsetOfr)
-      return failure();
 
     // At this point, we know that the back and forth between extract strided
     // metadata and reinterpret cast is a noop. However, the final type of the
diff --git a/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
index f1cb9c9f165be..fe91d26d5a251 100644
--- a/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
@@ -195,10 +195,10 @@ func.func @subview_const_stride(%0 : memref<64x4xf32, strided<[4, 1], offset: 0>
 func.func @subview_const_stride_and_offset(%0 : memref<64x8xf32, strided<[8, 1], offset: 0>>) -> memref<62x3xf32, strided<[8, 1], offset: 2>> {
   // The last "insertvalue" that populates the memref descriptor from the function arguments.
   // CHECK: %[[MEMREF:.*]] = builtin.unrealized_conversion_cast %[[MEM]]
-  // CHECK: %[[DESC:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 
   // CHECK: %[[BASE:.*]] = llvm.extractvalue %[[MEMREF]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
   // CHECK: %[[BASE_ALIGNED:.*]] = llvm.extractvalue %[[MEMREF]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+  // CHECK: %[[DESC:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
   // CHECK: %[[DESC0:.*]] = llvm.insertvalue %[[BASE]], %[[DESC]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
   // CHECK: %[[DESC1:.*]] = llvm.insertvalue %[[BASE_ALIGNED]], %[[DESC0]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
   // CHECK: %[[CST_OFF:.*]] = llvm.mlir.constant(2 : index) : i64
@@ -265,11 +265,11 @@ func.func @subview_mixed_static_dynamic(%0 : memref<64x4xf32, strided<[4, 1], of
 // CHECK:         %[[MEM:.*]]: memref<{{.*}}>,
 func.func @subview_leading_operands(%0 : memref<5x3xf32>, %1: memref<5x?xf32>) -> memref<3x3xf32, strided<[3, 1], offset: 6>> {
   // CHECK: %[[MEMREF:.*]] = builtin.unrealized_conversion_cast %[[MEM]]
-  // CHECK: %[[DESC:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
   // Alloc ptr
   // CHECK: %[[BASE:.*]] = llvm.extractvalue %[[MEMREF]][0] : !llvm.struct<(ptr, ptr, i64
   // Aligned ptr
   // CHECK: %[[BASE_ALIGNED:.*]] = llvm.extractvalue %[[MEMREF]][1] : !llvm.struct<(ptr, ptr, i64
+  // CHECK: %[[DESC:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
   // CHECK: %[[DESC0:.*]] = llvm.insertvalue %[[BASE]], %[[DESC]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
   // CHECK: %[[DESC1:.*]] = llvm.insertvalue %[[BASE_ALIGNED]], %[[DESC0]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
   // Offset
@@ -331,9 +331,9 @@ func.func @subview_leading_operands_dynamic(%0 : memref<5x?xf32>) -> memref<3x?x
 // CHECK:         %[[MEM:.*]]: memref
 func.func @subview_rank_reducing_leading_operands(%0 : memref<5x3xf32>) -> memref<3xf32, strided<[1], offset: 3>> {
   // CHECK: %[[MEMREF:.*]] = builtin.unrealized_conversion_cast %[[MEM]]
-  // CHECK: %[[DESC:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
   // CHECK: %[[BASE:.*]] = llvm.extractvalue %[[MEMREF]][0] : !llvm.struct<(ptr, ptr, i64
   // CHECK: %[[BASE_ALIGNED:.*]] = llvm.extractvalue %[[MEMREF]][1] : !llvm.struct<(ptr, ptr, i64
+  // CHECK: %[[DESC:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
   // Alloc ptr
   // CHECK: %[[DESC0:.*]] = llvm.insertvalue %[[BASE]], %[[DESC]][0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
   // Aligned ptr
@@ -356,9 +356,9 @@ func.func @subview_rank_reducing_leading_operands(%0 : memref<5x3xf32>) -> memre
 // CHECK-SAME: (%[[MEM:.*]]: memref<7xf32>)
 func.func @subview_negative_stride(%arg0 : memref<7xf32>) -> memref<7xf32, strided<[-1], offset: 6>> {
   // CHECK: %[[MEMREF:.*]] = builtin.unrealized_conversion_cast %[[MEM]]
-  // CHECK: %[[DESC:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
   // CHECK: %[[BASE:.*]] = llvm.extractvalue %[[MEMREF]][0] : !llvm.struct<(ptr, ptr, i64
   // CHECK: %[[BASE_ALIGNED:.*]] = llvm.extractvalue %[[MEMREF]][1] : !llvm.struct<(ptr, ptr, i64
+  // CHECK: %[[DESC:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
   // CHECK: %[[DESC0:.*]] = llvm.insertvalue %[[BASE]], %[[DESC]][0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
   // CHECK: %[[DESC1:.*]] = llvm.insertvalue %[[BASE_ALIGNED]], %[[DESC0]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
   // CHECK: %[[CST_OFF0:.*]] = llvm.mlir.constant(6 : index) : i64
@@ -384,12 +384,12 @@ func.func @collapse_shape_static(%arg0: memref<1x3x4x1x5xf32>) -> memref<3x4x5xf
 // CHECK-LABEL: func @collapse_shape_static
 // CHECK-SAME: %[[ARG:.*]]: memref<1x3x4x1x5xf32>) -> memref<3x4x5xf32> {
 // CHECK:           %[[MEM:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : memref<1x3x4x1x5xf32> to !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)>
-// CHECK:           %[[DESC:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
 // CHECK:           %[[BASE_BUFFER:.*]] = llvm.extractvalue %[[MEM]][0] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)>
 // CHECK:           %[[ALIGNED_BUFFER:.*]] = llvm.extractvalue %[[MEM]][1] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)>
+// CHECK:           %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64
+// CHECK:           %[[DESC:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
 // CHECK:           %[[DESC0:.*]] = llvm.insertvalue %[[BASE_BUFFER]], %[[DESC]][0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
 // CHECK:           %[[DESC1:.*]] = llvm.insertvalue %[[ALIGNED_BUFFER]], %[[DESC0]][1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64
 // CHECK:           %[[DESC2:.*]] = llvm.insertvalue %[[C0]], %[[DESC1]][2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
 // CHECK:           %[[C3:.*]] = llvm.mlir.constant(3 : index) : i64
 // CHECK:           %[[DESC3:.*]] = llvm.insertvalue %[[C3]], %[[DESC2]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
@@ -458,12 +458,12 @@ func.func @expand_shape_static(%arg0: memref<3x4x5xf32>) -> memref<1x3x4x1x5xf32
 // CHECK-LABEL: func @expand_shape_static
 // CHECK-SAME: %[[ARG:.*]]: memref<3x4x5xf32>) -> memref<1x3x4x1x5xf32> {
 // CHECK:           %[[MEM:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : memref<3x4x5xf32> to !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           %[[DESC:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)>
 // CHECK:           %[[BASE_BUFFER:.*]] = llvm.extractvalue %[[MEM]][0] : !llvm.struct<(ptr, ptr, i64,
 // CHECK:           %[[ALIGNED_BUFFER:.*]] = llvm.extractvalue %[[MEM]][1] : !llvm.struct<(ptr, ptr, i64,
+// CHECK:           %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64
+// CHECK:           %[[DESC:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)>
 // CHECK:           %[[DESC0:.*]] = llvm.insertvalue %[[BASE_BUFFER]], %[[DESC]][0] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)>
 // CHECK:           %[[DESC1:.*]] = llvm.insertvalue %[[ALIGNED_BUFFER]], %[[DESC0]][1] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)>
-// CHECK:           %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64
 // CHECK:           %[[DESC2:.*]] = llvm.insertvalue %[[C0]], %[[DESC1]][2] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)>
 // CHECK:           %[[C1:.*]] = llvm.mlir.constant(1 : index) : i64
 // CHECK:           %[[DESC3:.*]] = llvm.insertvalue %[[C1]], %[[DESC2]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)>
@@ -494,9 +494,9 @@ func.func @collapse_shape_fold_zero_dim(%arg0 : memref<1x1xf32>) -> memref<f32>
 // CHECK-LABEL:   func.func @collapse_shape_fold_zero_dim(
 // CHECK-SAME:                                            %[[ARG:.*]]: memref<1x1xf32>) -> memref<f32> {
 // CHECK:           %[[MEM:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : memref<1x1xf32> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK:           %[[DESC:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64)>
 // CHECK:           %[[BASE_BUFFER:.*]] = llvm.extractvalue %[[MEM]][0] : !llvm.struct<(ptr, ptr, i64,
 // CHECK:           %[[ALIGNED_BUFFER:.*]] = llvm.extractvalue %[[MEM]][1] : !llvm.struct<(ptr, ptr, i64,
+// CHECK:           %[[DESC:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64)>
 // CHECK:           %[[DESC0:.*]] = llvm.insertvalue %[[BASE_BUFFER]], %[[DESC]][0] : !llvm.struct<(ptr, ptr, i64)>
 // CHECK:           %[[DESC1:.*]] = llvm.insertvalue %[[ALIGNED_BUFFER]], %[[DESC0]][1] : !llvm.struct<(ptr, ptr, i64)>
 // CHECK:           %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64
@@ -515,12 +515,12 @@ func.func @expand_shape_zero_dim(%arg0 : memref<f32>) -> memref<1x1xf32> {
 // CHECK-LABEL:   func.func @expand_shape_zero_dim(
 // CHECK-SAME:                                     %[[ARG:.*]]: memref<f32>) -> memref<1x1xf32> {
 // CHECK:           %[[MEM:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : memref<f32> to !llvm.struct<(ptr, ptr, i64)>
-// CHECK:           %[[DESC:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK:           %[[BASE_BUFFER:.*]] = llvm.extractvalue %[[MEM]][0] : !llvm.struct<(ptr, ptr, i64)>
 // CHECK:           %[[ALIGNED_BUFFER:.*]] = llvm.extractvalue %[[MEM]][1] : !llvm.struct<(ptr, ptr, i64)>
+// CHECK:           %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64
+// CHECK:           %[[DESC:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK:           %[[DESC0:.*]] = llvm.insertvalue %[[BASE_BUFFER]], %[[DESC]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK:           %[[DESC1:.*]] = llvm.insertvalue %[[ALIGNED_BUFFER]], %[[DESC0]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK:           %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64
 // CHECK:           %[[DESC2:.*]] = llvm.insertvalue %[[C0]], %[[DESC1]][2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK:           %[[C1:.*]] = llvm.mlir.constant(1 : index) : i64
 // CHECK:           %[[DESC3:.*]] = llvm.insertvalue %[[C1]], %[[DESC2]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 5d8a7d3f64e8f..e7cee7cd85426 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -952,8 +952,7 @@ func.func @reinterpret_of_extract_strided_metadata_same_type(%arg0 : memref<?x?x
 //  CHECK-SAME: (%[[ARG:.*]]: memref<8x2xf32>)
 //   CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
 //   CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-//   CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]]
-//       CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[C0]]], sizes: [4, 2, 2], strides: [1, 1, %[[C1]]]
+//       CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [%[[C0]]], sizes: [4, 2, 2], strides: [1, 1, %[[C1]]]
 //       CHECK: return %[[RES]]
 func.func @reinterpret_of_extract_strided_metadata_w_different_stride(%arg0 : memref<8x2xf32>) -> memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>> {
   %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref<f32>, index, index, index, index, index
@@ -969,8 +968,7 @@ func.func @reinterpret_of_extract_strided_metadata_w_different_stride(%arg0 : me
 //   CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
 //   CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
 //   CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-//   CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]]
-//       CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [1], sizes: [%[[C8]], %[[C2]]], strides: [%[[C2]], %[[C1]]]
+//       CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [1], sizes: [%[[C8]], %[[C2]]], strides: [%[[C2]], %[[C1]]]
 //       CHECK: return %[[RES]]
 func.func @reinterpret_of_extract_strided_metadata_w_different_offset(%arg0 : memref<8x2xf32>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> {
   %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref<f32>, index, index, index, index, index
diff --git a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
index da74c73ccd7a5..1e6b0111fa4c7 100644
--- a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
+++ b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
@@ -975,7 +975,11 @@ func.func @simplify_collapse(%arg : memref<?x?x4x?x6x7xi32>)
 //
 // CHECK-LABEL: func @simplify_collapse_with_dim_of_size1(
 //  CHECK-SAME: %[[ARG:.*]]: memref<3x1xf32, strided<[2, 1]>>,
-//       CHECK: %[[COLLAPSE_VIEW:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [0], sizes: [3], strides: [2]
+//
+//       CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]] : memref<3x1xf32, strided<[2, 1]>>
+//
+//
+//       CHECK: %[[COLLAPSE_VIEW:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [0], sizes: [3], strides: [2]
 func.func @simplify_collapse_with_dim_of_size1(%arg0: memref<3x1xf32, strided<[2,1]>>, %arg1: memref<3xf32>) {
 
   %collapse_shape = memref.collapse_shape %arg0 [[0, 1]] :
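
A sketch of the operand-rewiring case this revision introduces (hypothetical IR; @rewire_example is an illustrative name, not a test from this patch). When the reinterpret_cast builds a layout that differs from the source's, the pattern no longer bails out; it only replaces the cast's source, and dead code elimination can then drop the extract_strided_metadata, matching the updated canonicalize.mlir checks above:

func.func @rewire_example(%arg0 : memref<8x2xf32>) -> memref<4xf32, strided<[2]>> {
  %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0
      : memref<8x2xf32> -> memref<f32>, index, index, index, index, index
  // The target layout (4 elements, stride 2) differs from %arg0's layout, so
  // the pair is not a noop; the pattern only rewires %base to %arg0.
  %cast = memref.reinterpret_cast %base to offset: [0], sizes: [4], strides: [2]
      : memref<f32> to memref<4xf32, strided<[2]>>
  return %cast : memref<4xf32, strided<[2]>>
}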

>From 0e5d3f5e1caf511a948a5063f9b2949fe57a94e0 Mon Sep 17 00:00:00 2001
From: Ivan Garcia <igarcia at vdi-ah2ddp-178.dhcp.mathworks.com>
Date: Wed, 9 Apr 2025 08:14:15 -0400
Subject: [PATCH 3/3] Addressing second round of feedback from Matthias.

---
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 18 +++++++++---------
 1 file changed, 9 insertions(+), 9 deletions(-)

diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 895baecd8c42b..2cd3a7d6cfa9f 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2035,10 +2035,10 @@ namespace {
 /// Because we know that `offset` and `c0` will hold 0
 /// and `c4` will hold 4.
 ///
-/// If the pattern above does not match, the input of the extract_strided_metadata
-/// is always folded into the input of the reinterpret_cast operator. This allows
-/// for dead code elimination to get rid of the extract_strided_metadata in some
-/// cases.
+/// If the pattern above does not match, the input of the
+/// extract_strided_metadata is always folded into the input of the
+/// reinterpret_cast operator. This allows for dead code elimination to get rid
+/// of the extract_strided_metadata in some cases.
 struct ReinterpretCastOpExtractStridedMetadataFolder
     : public OpRewritePattern<ReinterpretCastOp> {
 public:
@@ -2070,8 +2070,8 @@ struct ReinterpretCastOpExtractStridedMetadataFolder
 
       // Second, check the sizes.
       assert(extractStridedMetadata.getSizes().size() ==
-                op.getMixedSizes().size() &&
-            "Strides and sizes rank must match");
+                 op.getMixedSizes().size() &&
+             "Strides and sizes rank must match");
       SmallVector<OpFoldResult> extractSizesOfr =
           extractStridedMetadata.getConstifiedMixedSizes();
       SmallVector<OpFoldResult> reinterpretSizesOfr =
@@ -2082,8 +2082,8 @@ struct ReinterpretCastOpExtractStridedMetadataFolder
       }
       // Finally, check the offset.
       assert(op.getMixedOffsets().size() == 1 &&
-            "reinterpret_cast with more than one offset should have been "
-            "rejected by the verifier");
+             "reinterpret_cast with more than one offset should have been "
+             "rejected by the verifier");
       OpFoldResult extractOffsetOfr =
           extractStridedMetadata.getConstifiedMixedOffset();
       OpFoldResult reinterpretOffsetOfr = op.getConstifiedMixedOffset();
@@ -2094,7 +2094,7 @@ struct ReinterpretCastOpExtractStridedMetadataFolder
       // If the extract_strided_metadata / reinterpret_cast pair can't be
       // completely folded, then we could fold the input of the
       // extract_strided_metadata into the input of the reinterpret_cast
-      // input. For some cases (e.g., static dimensions) the 
+      // input. For some cases (e.g., static dimensions)
       // the extract_strided_metadata is eliminated by dead code elimination.
       //
       // reinterpret_cast(extract_strided_metadata(x)) -> reinterpret_cast(x).
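
For completeness, a sketch of the noop case accepted by the isReinterpretCastNoop check (hypothetical IR in the spirit of the existing @reinterpret_of_extract_strided_metadata_same_type test). The cast rebuilds the source layout exactly, so the whole pair folds away:

func.func @noop_example(%arg0 : memref<8x2xf32>) -> memref<8x2xf32> {
  %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0
      : memref<8x2xf32> -> memref<f32>, index, index, index, index, index
  // Offset 0, sizes [8, 2], and strides [2, 1] reproduce %arg0's layout.
  %cast = memref.reinterpret_cast %base to offset: [0], sizes: [8, 2], strides: [2, 1]
      : memref<f32> to memref<8x2xf32>
  return %cast : memref<8x2xf32>
}

Canonicalization reduces this to just `return %arg0 : memref<8x2xf32>`.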



More information about the Mlir-commits mailing list