[Mlir-commits] [mlir] Folding extract_strided_metadata input into reinterpret_cast on constant layout (PR #134845)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Apr 8 04:49:44 PDT 2025


https://github.com/ivangarcia44 created https://github.com/llvm/llvm-project/pull/134845

We can always fold the input of an extract_strided_metadata operation into the input of a reinterpret_cast operation, because both point to the same memory. Note that the reinterpret_cast does not use the layout of its input memref, only its base memory pointer, which is the same as the base pointer returned by the extract_strided_metadata operation and the base pointer of the memref passed into extract_strided_metadata.

This folding is only profitable when the reinterpret_cast layout is constant, because only then do the extract_strided_metadata results become unused and the operation is removed by dead code elimination. With a non-constant layout the extract_strided_metadata is not eliminated, and one of the LLVM tests regresses in performance because the folding gets in the way of another optimization. For this reason, the folding is only done on constant layouts.
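To illustrate the distinction, here is a minimal sketch (the value names and the 4x8 shape are hypothetical, not taken from the patch). In the first reinterpret_cast the layout is fully constant, so the fold applies and the metadata operation becomes dead; in the second, a size comes from the metadata results, so the fold is skipped:

%base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %src : memref<4x8xf32> -> memref<f32>, index, index, index, index, index

// Constant layout: %base can be replaced by %src, leaving the metadata op dead.
%cst_view = memref.reinterpret_cast %base to offset: [0], sizes: [4, 8], strides: [8, 1] : memref<f32> to memref<4x8xf32>

// Non-constant layout: %sizes#0 keeps the metadata op alive, so no fold is done.
%dyn_view = memref.reinterpret_cast %base to offset: [0], sizes: [%sizes#0, 8], strides: [8, 1] : memref<f32> to memref<?x8xf32>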

Operations like expand_shape, collapse_shape, and subview are lowered to a pair of extract_strided_metadata and reinterpret_cast like this:
      
%base_buffer, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %input_memref : memref<ID1x...xIDNxBaseType> -> memref<BaseType>, index, index, index, index, index

%reinterpret_cast = memref.reinterpret_cast %base_buffer to offset: [%o1], sizes: [%d1,...,%dN], strides: [%s1,...,%sN] : memref<BaseType> to memref<OD1x...xODNxBaseType>

In many cases the input of the extract_strided_metadata can be passed directly to the reinterpret_cast operation, like this (note how %base_buffer from the reinterpret_cast above is replaced by %input_memref and the source type is updated accordingly):

%base_buffer, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %input_memref : memref<ID1x...xIDNxBaseType> -> memref<BaseType>, index, index, index, index, index
%reinterpret_cast = memref.reinterpret_cast %input_memref to offset: [%o1], sizes: [%d1,...,%dN], strides: [%s1,...,%sN] : memref<ID1x...xIDNxBaseType> to memref<OD1x...xODNxBaseType>

When dealing with static dimensions, the extract_strided_metadata becomes dead code and we end up with only the reinterpret_cast:

%reinterpret_cast = memref.reinterpret_cast %input_memref to offset: [%o1], sizes: [%d1,...,%dN], strides: [%s1,...,%sN] : memref<ID1x...xIDNxBaseType> to memref<OD1x...xODNxBaseType>

Note that reinterpret_cast only reads the base memory pointer from the input memref (%input_memref above), which is equivalent to the %base_buffer returned by the extract_strided_metadata operation. Hence it is always legal to use the extract_strided_metadata input memref directly in the reinterpret_cast. And because only the pointer is read, the fold remains legal even when the values stored at the base pointer are modified between the two operations.
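The net effect can be seen in the simplify_collapse_with_dim_of_size1 test updated by this patch (value names follow the test's capture names, and the full types are elided as in the CHECK lines). Before the fold, the collapse of a memref<3x1xf32, strided<[2, 1]>> expanded to:

%base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg : memref<3x1xf32, strided<[2, 1]>>
%collapse_view = memref.reinterpret_cast %base to offset: [0], sizes: [3], strides: [2]

and after it the metadata operation is gone:

%collapse_view = memref.reinterpret_cast %arg to offset: [0], sizes: [3], strides: [2]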


>From dd31b5f193e7e1f1df8a0c6da3a5c304f85da990 Mon Sep 17 00:00:00 2001
From: Ivan Garcia <igarcia at vdi-ah2ddp-178.dhcp.mathworks.com>
Date: Tue, 8 Apr 2025 07:37:41 -0400
Subject: [PATCH] Folding extract_strided_metadata input into reinterpret_cast
 on constant layout.

---
 .../mlir/Dialect/MemRef/IR/MemRefOps.td       |  3 ++
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp      | 39 ++++++++++++++++++-
 .../expand-then-convert-to-llvm.mlir          | 22 +++++------
 .../MemRef/expand-strided-metadata.mlir       |  6 +--
 4 files changed, 53 insertions(+), 17 deletions(-)

diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 3edc2433c85ea..990a282771502 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1440,6 +1440,9 @@ def MemRef_ReinterpretCastOp
     SmallVector<OpFoldResult> getConstifiedMixedStrides();
     /// Similar to `getConstifiedMixedSizes` but for the offset.
     OpFoldResult getConstifiedMixedOffset();
+    /// Returns true if the reinterpret cast operation's offset, stride, and 
+    /// size are all constant.
+    bool isLayoutConstant();
   }];
 
   let hasFolder = 1;
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 123666848f83a..629d0d8d425b1 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1124,7 +1124,7 @@ struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
         }
       } // else dim.getIndex is a block argument to reshape->getBlock and
         // dominates reshape
-    }   // Check condition 2
+    } // Check condition 2
     else if (dim->getBlock() != reshape->getBlock() &&
              !dim.getIndex().getParentRegion()->isProperAncestor(
                  reshape->getParentRegion())) {
@@ -1948,6 +1948,27 @@ OpFoldResult ReinterpretCastOp::fold(FoldAdaptor /*operands*/) {
     if (auto prev = src.getDefiningOp<CastOp>())
       return prev.getSource();
 
+    // reinterpret_cast(extract_strided_metadata(x)) -> reinterpret_cast(x).
+    //
+    // We can always fold the input of a extract_strided_metadata operator
+    // to the input of a reinterpret_cast operator, because they point to
+    // the same memory. Note that the reinterpret_cast does not use the
+    // layout of its input memref, only its base memory pointer which is
+    // the same as the base pointer returned by the extract_strided_metadata
+    // operator and the base pointer of the extract_strided_metadata memref
+    // input. This folding is only profitable when the reinterpret_cast
+    // layout is constant, because the extract_strided_metadata gets
+    // eliminated by dead code elimination. For non-constant folding we don’t
+    // get the extract_strided_metadata node eliminated and one of the LLVM
+    // tests regress in performance because the folding gets in the way of
+    // another optimization. For this reason the folding is only done on
+    // constant layout.
+    if (auto prev = src.getDefiningOp<ExtractStridedMetadataOp>()) {
+      if (isLayoutConstant()) {
+        return prev.getSource();
+      }
+    }
+
     // reinterpret_cast(subview(x)) -> reinterpret_cast(x) if subview offsets
     // are 0.
     if (auto prev = src.getDefiningOp<SubViewOp>())
@@ -1973,6 +1994,22 @@ OpFoldResult ReinterpretCastOp::fold(FoldAdaptor /*operands*/) {
   return nullptr;
 }
 
+bool ReinterpretCastOp::isLayoutConstant() {
+  if (llvm::all_of(
+          getOffsets(),
+          [](OpFoldResult val) { return isConstantIntValue(val, 0); }) &&
+      llvm::all_of(
+          getStrides(),
+          [](OpFoldResult val) { return isConstantIntValue(val, 0); }) &&
+      llvm::all_of(getSizes(), [](OpFoldResult val) {
+        return isConstantIntValue(val, 0);
+      })) {
+    return true;
+  } else {
+    return false;
+  }
+}
+
 SmallVector<OpFoldResult> ReinterpretCastOp::getConstifiedMixedSizes() {
   SmallVector<OpFoldResult> values = getMixedSizes();
   constifyIndexValues(values, getType(), getContext(), getConstantSizes,
diff --git a/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
index fe91d26d5a251..f1cb9c9f165be 100644
--- a/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
@@ -195,10 +195,10 @@ func.func @subview_const_stride(%0 : memref<64x4xf32, strided<[4, 1], offset: 0>
 func.func @subview_const_stride_and_offset(%0 : memref<64x8xf32, strided<[8, 1], offset: 0>>) -> memref<62x3xf32, strided<[8, 1], offset: 2>> {
   // The last "insertvalue" that populates the memref descriptor from the function arguments.
   // CHECK: %[[MEMREF:.*]] = builtin.unrealized_conversion_cast %[[MEM]]
+  // CHECK: %[[DESC:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 
   // CHECK: %[[BASE:.*]] = llvm.extractvalue %[[MEMREF]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
   // CHECK: %[[BASE_ALIGNED:.*]] = llvm.extractvalue %[[MEMREF]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
-  // CHECK: %[[DESC:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
   // CHECK: %[[DESC0:.*]] = llvm.insertvalue %[[BASE]], %[[DESC]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
   // CHECK: %[[DESC1:.*]] = llvm.insertvalue %[[BASE_ALIGNED]], %[[DESC0]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
   // CHECK: %[[CST_OFF:.*]] = llvm.mlir.constant(2 : index) : i64
@@ -265,11 +265,11 @@ func.func @subview_mixed_static_dynamic(%0 : memref<64x4xf32, strided<[4, 1], of
 // CHECK:         %[[MEM:.*]]: memref<{{.*}}>,
 func.func @subview_leading_operands(%0 : memref<5x3xf32>, %1: memref<5x?xf32>) -> memref<3x3xf32, strided<[3, 1], offset: 6>> {
   // CHECK: %[[MEMREF:.*]] = builtin.unrealized_conversion_cast %[[MEM]]
+  // CHECK: %[[DESC:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
   // Alloc ptr
   // CHECK: %[[BASE:.*]] = llvm.extractvalue %[[MEMREF]][0] : !llvm.struct<(ptr, ptr, i64
   // Aligned ptr
   // CHECK: %[[BASE_ALIGNED:.*]] = llvm.extractvalue %[[MEMREF]][1] : !llvm.struct<(ptr, ptr, i64
-  // CHECK: %[[DESC:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
   // CHECK: %[[DESC0:.*]] = llvm.insertvalue %[[BASE]], %[[DESC]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
   // CHECK: %[[DESC1:.*]] = llvm.insertvalue %[[BASE_ALIGNED]], %[[DESC0]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
   // Offset
@@ -331,9 +331,9 @@ func.func @subview_leading_operands_dynamic(%0 : memref<5x?xf32>) -> memref<3x?x
 // CHECK:         %[[MEM:.*]]: memref
 func.func @subview_rank_reducing_leading_operands(%0 : memref<5x3xf32>) -> memref<3xf32, strided<[1], offset: 3>> {
   // CHECK: %[[MEMREF:.*]] = builtin.unrealized_conversion_cast %[[MEM]]
+  // CHECK: %[[DESC:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
   // CHECK: %[[BASE:.*]] = llvm.extractvalue %[[MEMREF]][0] : !llvm.struct<(ptr, ptr, i64
   // CHECK: %[[BASE_ALIGNED:.*]] = llvm.extractvalue %[[MEMREF]][1] : !llvm.struct<(ptr, ptr, i64
-  // CHECK: %[[DESC:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
   // Alloc ptr
   // CHECK: %[[DESC0:.*]] = llvm.insertvalue %[[BASE]], %[[DESC]][0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
   // Aligned ptr
@@ -356,9 +356,9 @@ func.func @subview_rank_reducing_leading_operands(%0 : memref<5x3xf32>) -> memre
 // CHECK-SAME: (%[[MEM:.*]]: memref<7xf32>)
 func.func @subview_negative_stride(%arg0 : memref<7xf32>) -> memref<7xf32, strided<[-1], offset: 6>> {
   // CHECK: %[[MEMREF:.*]] = builtin.unrealized_conversion_cast %[[MEM]]
+  // CHECK: %[[DESC:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
   // CHECK: %[[BASE:.*]] = llvm.extractvalue %[[MEMREF]][0] : !llvm.struct<(ptr, ptr, i64
   // CHECK: %[[BASE_ALIGNED:.*]] = llvm.extractvalue %[[MEMREF]][1] : !llvm.struct<(ptr, ptr, i64
-  // CHECK: %[[DESC:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
   // CHECK: %[[DESC0:.*]] = llvm.insertvalue %[[BASE]], %[[DESC]][0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
   // CHECK: %[[DESC1:.*]] = llvm.insertvalue %[[BASE_ALIGNED]], %[[DESC0]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
   // CHECK: %[[CST_OFF0:.*]] = llvm.mlir.constant(6 : index) : i64
@@ -384,12 +384,12 @@ func.func @collapse_shape_static(%arg0: memref<1x3x4x1x5xf32>) -> memref<3x4x5xf
 // CHECK-LABEL: func @collapse_shape_static
 // CHECK-SAME: %[[ARG:.*]]: memref<1x3x4x1x5xf32>) -> memref<3x4x5xf32> {
 // CHECK:           %[[MEM:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : memref<1x3x4x1x5xf32> to !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)>
+// CHECK:           %[[DESC:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
 // CHECK:           %[[BASE_BUFFER:.*]] = llvm.extractvalue %[[MEM]][0] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)>
 // CHECK:           %[[ALIGNED_BUFFER:.*]] = llvm.extractvalue %[[MEM]][1] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)>
-// CHECK:           %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64
-// CHECK:           %[[DESC:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
 // CHECK:           %[[DESC0:.*]] = llvm.insertvalue %[[BASE_BUFFER]], %[[DESC]][0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
 // CHECK:           %[[DESC1:.*]] = llvm.insertvalue %[[ALIGNED_BUFFER]], %[[DESC0]][1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64
 // CHECK:           %[[DESC2:.*]] = llvm.insertvalue %[[C0]], %[[DESC1]][2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
 // CHECK:           %[[C3:.*]] = llvm.mlir.constant(3 : index) : i64
 // CHECK:           %[[DESC3:.*]] = llvm.insertvalue %[[C3]], %[[DESC2]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
@@ -458,12 +458,12 @@ func.func @expand_shape_static(%arg0: memref<3x4x5xf32>) -> memref<1x3x4x1x5xf32
 // CHECK-LABEL: func @expand_shape_static
 // CHECK-SAME: %[[ARG:.*]]: memref<3x4x5xf32>) -> memref<1x3x4x1x5xf32> {
 // CHECK:           %[[MEM:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : memref<3x4x5xf32> to !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[DESC:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)>
 // CHECK:           %[[BASE_BUFFER:.*]] = llvm.extractvalue %[[MEM]][0] : !llvm.struct<(ptr, ptr, i64,
 // CHECK:           %[[ALIGNED_BUFFER:.*]] = llvm.extractvalue %[[MEM]][1] : !llvm.struct<(ptr, ptr, i64,
-// CHECK:           %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64
-// CHECK:           %[[DESC:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)>
 // CHECK:           %[[DESC0:.*]] = llvm.insertvalue %[[BASE_BUFFER]], %[[DESC]][0] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)>
 // CHECK:           %[[DESC1:.*]] = llvm.insertvalue %[[ALIGNED_BUFFER]], %[[DESC0]][1] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)>
+// CHECK:           %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64
 // CHECK:           %[[DESC2:.*]] = llvm.insertvalue %[[C0]], %[[DESC1]][2] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)>
 // CHECK:           %[[C1:.*]] = llvm.mlir.constant(1 : index) : i64
 // CHECK:           %[[DESC3:.*]] = llvm.insertvalue %[[C1]], %[[DESC2]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)>
@@ -494,9 +494,9 @@ func.func @collapse_shape_fold_zero_dim(%arg0 : memref<1x1xf32>) -> memref<f32>
 // CHECK-LABEL:   func.func @collapse_shape_fold_zero_dim(
 // CHECK-SAME:                                            %[[ARG:.*]]: memref<1x1xf32>) -> memref<f32> {
 // CHECK:           %[[MEM:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : memref<1x1xf32> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK:           %[[DESC:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64)>
 // CHECK:           %[[BASE_BUFFER:.*]] = llvm.extractvalue %[[MEM]][0] : !llvm.struct<(ptr, ptr, i64,
 // CHECK:           %[[ALIGNED_BUFFER:.*]] = llvm.extractvalue %[[MEM]][1] : !llvm.struct<(ptr, ptr, i64,
-// CHECK:           %[[DESC:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64)>
 // CHECK:           %[[DESC0:.*]] = llvm.insertvalue %[[BASE_BUFFER]], %[[DESC]][0] : !llvm.struct<(ptr, ptr, i64)>
 // CHECK:           %[[DESC1:.*]] = llvm.insertvalue %[[ALIGNED_BUFFER]], %[[DESC0]][1] : !llvm.struct<(ptr, ptr, i64)>
 // CHECK:           %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64
@@ -515,12 +515,12 @@ func.func @expand_shape_zero_dim(%arg0 : memref<f32>) -> memref<1x1xf32> {
 // CHECK-LABEL:   func.func @expand_shape_zero_dim(
 // CHECK-SAME:                                     %[[ARG:.*]]: memref<f32>) -> memref<1x1xf32> {
 // CHECK:           %[[MEM:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : memref<f32> to !llvm.struct<(ptr, ptr, i64)>
+// CHECK:           %[[DESC:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK:           %[[BASE_BUFFER:.*]] = llvm.extractvalue %[[MEM]][0] : !llvm.struct<(ptr, ptr, i64)>
 // CHECK:           %[[ALIGNED_BUFFER:.*]] = llvm.extractvalue %[[MEM]][1] : !llvm.struct<(ptr, ptr, i64)>
-// CHECK:           %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64
-// CHECK:           %[[DESC:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK:           %[[DESC0:.*]] = llvm.insertvalue %[[BASE_BUFFER]], %[[DESC]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK:           %[[DESC1:.*]] = llvm.insertvalue %[[ALIGNED_BUFFER]], %[[DESC0]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK:           %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64
 // CHECK:           %[[DESC2:.*]] = llvm.insertvalue %[[C0]], %[[DESC1]][2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK:           %[[C1:.*]] = llvm.mlir.constant(1 : index) : i64
 // CHECK:           %[[DESC3:.*]] = llvm.insertvalue %[[C1]], %[[DESC2]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
diff --git a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
index 1e6b0111fa4c7..da74c73ccd7a5 100644
--- a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
+++ b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
@@ -975,11 +975,7 @@ func.func @simplify_collapse(%arg : memref<?x?x4x?x6x7xi32>)
 //
 // CHECK-LABEL: func @simplify_collapse_with_dim_of_size1(
 //  CHECK-SAME: %[[ARG:.*]]: memref<3x1xf32, strided<[2, 1]>>,
-//
-//       CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]] : memref<3x1xf32, strided<[2, 1]>>
-//
-//
-//       CHECK: %[[COLLAPSE_VIEW:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [0], sizes: [3], strides: [2]
+//       CHECK: %[[COLLAPSE_VIEW:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [0], sizes: [3], strides: [2]
 func.func @simplify_collapse_with_dim_of_size1(%arg0: memref<3x1xf32, strided<[2,1]>>, %arg1: memref<3xf32>) {
 
   %collapse_shape = memref.collapse_shape %arg0 [[0, 1]] :


