[Mlir-commits] [mlir] Folding extract_strided_metadata input into reinterpret_cast on constant layout (PR #134845)
llvmlistbot at llvm.org
Tue Apr 8 13:38:04 PDT 2025
https://github.com/ivangarcia44 updated https://github.com/llvm/llvm-project/pull/134845
From dd31b5f193e7e1f1df8a0c6da3a5c304f85da990 Mon Sep 17 00:00:00 2001
From: Ivan Garcia <igarcia at vdi-ah2ddp-178.dhcp.mathworks.com>
Date: Tue, 8 Apr 2025 07:37:41 -0400
Subject: [PATCH 1/2] Folding extract_strided_metadata input into
reinterpret_cast on constant layout.
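
For example, with a fully constant layout the fold rewrites the pair below
so that the reinterpret_cast consumes the original memref directly, leaving
the extract_strided_metadata dead (illustrative IR, not taken from the
tests):

  %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0
      : memref<8x2xf32> -> memref<f32>, index, index, index, index, index
  %cast = memref.reinterpret_cast %base to offset: [0], sizes: [4, 4],
      strides: [4, 1] : memref<f32> to memref<4x4xf32, strided<[4, 1]>>

  // After the fold:
  %cast = memref.reinterpret_cast %arg0 to offset: [0], sizes: [4, 4],
      strides: [4, 1] : memref<8x2xf32> to memref<4x4xf32, strided<[4, 1]>>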
---
.../mlir/Dialect/MemRef/IR/MemRefOps.td | 3 ++
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 39 ++++++++++++++++++-
.../expand-then-convert-to-llvm.mlir | 22 +++++------
.../MemRef/expand-strided-metadata.mlir | 6 +--
4 files changed, 53 insertions(+), 17 deletions(-)
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 3edc2433c85ea..990a282771502 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1440,6 +1440,9 @@ def MemRef_ReinterpretCastOp
SmallVector<OpFoldResult> getConstifiedMixedStrides();
/// Similar to `getConstifiedMixedSizes` but for the offset.
OpFoldResult getConstifiedMixedOffset();
+    /// Returns true if the offset, sizes, and strides of the
+    /// reinterpret_cast operation are all constant.
+    bool isLayoutConstant();
}];
let hasFolder = 1;
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 123666848f83a..629d0d8d425b1 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1124,7 +1124,7 @@ struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
}
} // else dim.getIndex is a block argument to reshape->getBlock and
// dominates reshape
- } // Check condition 2
+ } // Check condition 2
else if (dim->getBlock() != reshape->getBlock() &&
!dim.getIndex().getParentRegion()->isProperAncestor(
reshape->getParentRegion())) {
@@ -1948,6 +1948,27 @@ OpFoldResult ReinterpretCastOp::fold(FoldAdaptor /*operands*/) {
if (auto prev = src.getDefiningOp<CastOp>())
return prev.getSource();
+  // reinterpret_cast(extract_strided_metadata(x)) -> reinterpret_cast(x).
+  //
+  // We can always fold the source of an extract_strided_metadata op into
+  // the source of a reinterpret_cast op, because both point to the same
+  // memory. Note that the reinterpret_cast does not use the layout of its
+  // input memref, only its base pointer, which is the same as the base
+  // pointer returned by extract_strided_metadata and the base pointer of
+  // the memref fed into extract_strided_metadata. The fold is only
+  // profitable when the reinterpret_cast layout is constant, because then
+  // the extract_strided_metadata is removed by dead code elimination.
+  // With a non-constant layout the extract_strided_metadata survives, and
+  // one of the LLVM tests regresses in performance because the fold gets
+  // in the way of another optimization. For this reason the fold is only
+  // performed for constant layouts.
+ if (auto prev = src.getDefiningOp<ExtractStridedMetadataOp>()) {
+ if (isLayoutConstant()) {
+ return prev.getSource();
+ }
+ }
+
// reinterpret_cast(subview(x)) -> reinterpret_cast(x) if subview offsets
// are 0.
if (auto prev = src.getDefiningOp<SubViewOp>())
@@ -1973,6 +1994,22 @@ OpFoldResult ReinterpretCastOp::fold(FoldAdaptor /*operands*/) {
return nullptr;
}
+bool ReinterpretCastOp::isLayoutConstant() {
+  auto isConstant = [](OpFoldResult val) {
+    return getConstantIntValue(val).has_value();
+  };
+  return llvm::all_of(getOffsets(), isConstant) &&
+         llvm::all_of(getStrides(), isConstant) &&
+         llvm::all_of(getSizes(), isConstant);
+}
+
SmallVector<OpFoldResult> ReinterpretCastOp::getConstifiedMixedSizes() {
SmallVector<OpFoldResult> values = getMixedSizes();
constifyIndexValues(values, getType(), getContext(), getConstantSizes,
diff --git a/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
index fe91d26d5a251..f1cb9c9f165be 100644
--- a/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
@@ -195,10 +195,10 @@ func.func @subview_const_stride(%0 : memref<64x4xf32, strided<[4, 1], offset: 0>
func.func @subview_const_stride_and_offset(%0 : memref<64x8xf32, strided<[8, 1], offset: 0>>) -> memref<62x3xf32, strided<[8, 1], offset: 2>> {
// The last "insertvalue" that populates the memref descriptor from the function arguments.
// CHECK: %[[MEMREF:.*]] = builtin.unrealized_conversion_cast %[[MEM]]
+ // CHECK: %[[DESC:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[BASE:.*]] = llvm.extractvalue %[[MEMREF]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[BASE_ALIGNED:.*]] = llvm.extractvalue %[[MEMREF]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
- // CHECK: %[[DESC:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[DESC0:.*]] = llvm.insertvalue %[[BASE]], %[[DESC]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[DESC1:.*]] = llvm.insertvalue %[[BASE_ALIGNED]], %[[DESC0]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[CST_OFF:.*]] = llvm.mlir.constant(2 : index) : i64
@@ -265,11 +265,11 @@ func.func @subview_mixed_static_dynamic(%0 : memref<64x4xf32, strided<[4, 1], of
// CHECK: %[[MEM:.*]]: memref<{{.*}}>,
func.func @subview_leading_operands(%0 : memref<5x3xf32>, %1: memref<5x?xf32>) -> memref<3x3xf32, strided<[3, 1], offset: 6>> {
// CHECK: %[[MEMREF:.*]] = builtin.unrealized_conversion_cast %[[MEM]]
+ // CHECK: %[[DESC:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// Alloc ptr
// CHECK: %[[BASE:.*]] = llvm.extractvalue %[[MEMREF]][0] : !llvm.struct<(ptr, ptr, i64
// Aligned ptr
// CHECK: %[[BASE_ALIGNED:.*]] = llvm.extractvalue %[[MEMREF]][1] : !llvm.struct<(ptr, ptr, i64
- // CHECK: %[[DESC:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[DESC0:.*]] = llvm.insertvalue %[[BASE]], %[[DESC]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[DESC1:.*]] = llvm.insertvalue %[[BASE_ALIGNED]], %[[DESC0]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// Offset
@@ -331,9 +331,9 @@ func.func @subview_leading_operands_dynamic(%0 : memref<5x?xf32>) -> memref<3x?x
// CHECK: %[[MEM:.*]]: memref
func.func @subview_rank_reducing_leading_operands(%0 : memref<5x3xf32>) -> memref<3xf32, strided<[1], offset: 3>> {
// CHECK: %[[MEMREF:.*]] = builtin.unrealized_conversion_cast %[[MEM]]
+ // CHECK: %[[DESC:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: %[[BASE:.*]] = llvm.extractvalue %[[MEMREF]][0] : !llvm.struct<(ptr, ptr, i64
// CHECK: %[[BASE_ALIGNED:.*]] = llvm.extractvalue %[[MEMREF]][1] : !llvm.struct<(ptr, ptr, i64
- // CHECK: %[[DESC:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// Alloc ptr
// CHECK: %[[DESC0:.*]] = llvm.insertvalue %[[BASE]], %[[DESC]][0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// Aligned ptr
@@ -356,9 +356,9 @@ func.func @subview_rank_reducing_leading_operands(%0 : memref<5x3xf32>) -> memre
// CHECK-SAME: (%[[MEM:.*]]: memref<7xf32>)
func.func @subview_negative_stride(%arg0 : memref<7xf32>) -> memref<7xf32, strided<[-1], offset: 6>> {
// CHECK: %[[MEMREF:.*]] = builtin.unrealized_conversion_cast %[[MEM]]
+ // CHECK: %[[DESC:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: %[[BASE:.*]] = llvm.extractvalue %[[MEMREF]][0] : !llvm.struct<(ptr, ptr, i64
// CHECK: %[[BASE_ALIGNED:.*]] = llvm.extractvalue %[[MEMREF]][1] : !llvm.struct<(ptr, ptr, i64
- // CHECK: %[[DESC:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: %[[DESC0:.*]] = llvm.insertvalue %[[BASE]], %[[DESC]][0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: %[[DESC1:.*]] = llvm.insertvalue %[[BASE_ALIGNED]], %[[DESC0]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: %[[CST_OFF0:.*]] = llvm.mlir.constant(6 : index) : i64
@@ -384,12 +384,12 @@ func.func @collapse_shape_static(%arg0: memref<1x3x4x1x5xf32>) -> memref<3x4x5xf
// CHECK-LABEL: func @collapse_shape_static
// CHECK-SAME: %[[ARG:.*]]: memref<1x3x4x1x5xf32>) -> memref<3x4x5xf32> {
// CHECK: %[[MEM:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : memref<1x3x4x1x5xf32> to !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)>
+// CHECK: %[[DESC:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
// CHECK: %[[BASE_BUFFER:.*]] = llvm.extractvalue %[[MEM]][0] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)>
// CHECK: %[[ALIGNED_BUFFER:.*]] = llvm.extractvalue %[[MEM]][1] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)>
-// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64
-// CHECK: %[[DESC:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
// CHECK: %[[DESC0:.*]] = llvm.insertvalue %[[BASE_BUFFER]], %[[DESC]][0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
// CHECK: %[[DESC1:.*]] = llvm.insertvalue %[[ALIGNED_BUFFER]], %[[DESC0]][1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64
// CHECK: %[[DESC2:.*]] = llvm.insertvalue %[[C0]], %[[DESC1]][2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
// CHECK: %[[C3:.*]] = llvm.mlir.constant(3 : index) : i64
// CHECK: %[[DESC3:.*]] = llvm.insertvalue %[[C3]], %[[DESC2]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
@@ -458,12 +458,12 @@ func.func @expand_shape_static(%arg0: memref<3x4x5xf32>) -> memref<1x3x4x1x5xf32
// CHECK-LABEL: func @expand_shape_static
// CHECK-SAME: %[[ARG:.*]]: memref<3x4x5xf32>) -> memref<1x3x4x1x5xf32> {
// CHECK: %[[MEM:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : memref<3x4x5xf32> to !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK: %[[DESC:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)>
// CHECK: %[[BASE_BUFFER:.*]] = llvm.extractvalue %[[MEM]][0] : !llvm.struct<(ptr, ptr, i64,
// CHECK: %[[ALIGNED_BUFFER:.*]] = llvm.extractvalue %[[MEM]][1] : !llvm.struct<(ptr, ptr, i64,
-// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64
-// CHECK: %[[DESC:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)>
// CHECK: %[[DESC0:.*]] = llvm.insertvalue %[[BASE_BUFFER]], %[[DESC]][0] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)>
// CHECK: %[[DESC1:.*]] = llvm.insertvalue %[[ALIGNED_BUFFER]], %[[DESC0]][1] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)>
+// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64
// CHECK: %[[DESC2:.*]] = llvm.insertvalue %[[C0]], %[[DESC1]][2] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)>
// CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : index) : i64
// CHECK: %[[DESC3:.*]] = llvm.insertvalue %[[C1]], %[[DESC2]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)>
@@ -494,9 +494,9 @@ func.func @collapse_shape_fold_zero_dim(%arg0 : memref<1x1xf32>) -> memref<f32>
// CHECK-LABEL: func.func @collapse_shape_fold_zero_dim(
// CHECK-SAME: %[[ARG:.*]]: memref<1x1xf32>) -> memref<f32> {
// CHECK: %[[MEM:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : memref<1x1xf32> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[DESC:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64)>
// CHECK: %[[BASE_BUFFER:.*]] = llvm.extractvalue %[[MEM]][0] : !llvm.struct<(ptr, ptr, i64,
// CHECK: %[[ALIGNED_BUFFER:.*]] = llvm.extractvalue %[[MEM]][1] : !llvm.struct<(ptr, ptr, i64,
-// CHECK: %[[DESC:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64)>
// CHECK: %[[DESC0:.*]] = llvm.insertvalue %[[BASE_BUFFER]], %[[DESC]][0] : !llvm.struct<(ptr, ptr, i64)>
// CHECK: %[[DESC1:.*]] = llvm.insertvalue %[[ALIGNED_BUFFER]], %[[DESC0]][1] : !llvm.struct<(ptr, ptr, i64)>
// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64
@@ -515,12 +515,12 @@ func.func @expand_shape_zero_dim(%arg0 : memref<f32>) -> memref<1x1xf32> {
// CHECK-LABEL: func.func @expand_shape_zero_dim(
// CHECK-SAME: %[[ARG:.*]]: memref<f32>) -> memref<1x1xf32> {
// CHECK: %[[MEM:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : memref<f32> to !llvm.struct<(ptr, ptr, i64)>
+// CHECK: %[[DESC:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[BASE_BUFFER:.*]] = llvm.extractvalue %[[MEM]][0] : !llvm.struct<(ptr, ptr, i64)>
// CHECK: %[[ALIGNED_BUFFER:.*]] = llvm.extractvalue %[[MEM]][1] : !llvm.struct<(ptr, ptr, i64)>
-// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64
-// CHECK: %[[DESC:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[DESC0:.*]] = llvm.insertvalue %[[BASE_BUFFER]], %[[DESC]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[DESC1:.*]] = llvm.insertvalue %[[ALIGNED_BUFFER]], %[[DESC0]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64
// CHECK: %[[DESC2:.*]] = llvm.insertvalue %[[C0]], %[[DESC1]][2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : index) : i64
// CHECK: %[[DESC3:.*]] = llvm.insertvalue %[[C1]], %[[DESC2]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
diff --git a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
index 1e6b0111fa4c7..da74c73ccd7a5 100644
--- a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
+++ b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
@@ -975,11 +975,7 @@ func.func @simplify_collapse(%arg : memref<?x?x4x?x6x7xi32>)
//
// CHECK-LABEL: func @simplify_collapse_with_dim_of_size1(
// CHECK-SAME: %[[ARG:.*]]: memref<3x1xf32, strided<[2, 1]>>,
-//
-// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]] : memref<3x1xf32, strided<[2, 1]>>
-//
-//
-// CHECK: %[[COLLAPSE_VIEW:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [0], sizes: [3], strides: [2]
+// CHECK: %[[COLLAPSE_VIEW:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [0], sizes: [3], strides: [2]
func.func @simplify_collapse_with_dim_of_size1(%arg0: memref<3x1xf32, strided<[2,1]>>, %arg1: memref<3xf32>) {
%collapse_shape = memref.collapse_shape %arg0 [[0, 1]] :
From 94f4b167a99cd1435b5abc88654edacc2dd7edbf Mon Sep 17 00:00:00 2001
From: Ivan Garcia <igarcia at vdi-ah2ddp-178.dhcp.mathworks.com>
Date: Tue, 8 Apr 2025 16:37:36 -0400
Subject: [PATCH 2/2] Fixing an issue found by Matthias Springer and
improving/simplifying the change.
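
With this version the rewrite lives in the existing
ReinterpretCastOpExtractStridedMetadataFolder pattern: when the
reinterpret_cast does not reproduce the extracted metadata exactly (for
example, it uses a different offset), only its source operand is redirected
to the original memref. A sketch of the effect (illustrative IR, with static
values chosen for readability):

  %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0
      : memref<8x2xf32> -> memref<f32>, index, index, index, index, index
  %cast = memref.reinterpret_cast %base to offset: [1], sizes: [8, 2],
      strides: [2, 1] : memref<f32> to memref<8x2xf32, strided<[2, 1], offset: 1>>

  // becomes:
  %cast = memref.reinterpret_cast %arg0 to offset: [1], sizes: [8, 2],
      strides: [2, 1] : memref<8x2xf32> to memref<8x2xf32, strided<[2, 1], offset: 1>>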
---
.../mlir/Dialect/MemRef/IR/MemRefOps.td | 3 -
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 121 ++++++++----------
.../expand-then-convert-to-llvm.mlir | 22 ++--
mlir/test/Dialect/MemRef/canonicalize.mlir | 6 +-
.../MemRef/expand-strided-metadata.mlir | 6 +-
5 files changed, 73 insertions(+), 85 deletions(-)
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 990a282771502..3edc2433c85ea 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1440,9 +1440,6 @@ def MemRef_ReinterpretCastOp
SmallVector<OpFoldResult> getConstifiedMixedStrides();
/// Similar to `getConstifiedMixedSizes` but for the offset.
OpFoldResult getConstifiedMixedOffset();
-    /// Returns true if the offset, sizes, and strides of the
-    /// reinterpret_cast operation are all constant.
-    bool isLayoutConstant();
}];
let hasFolder = 1;
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 629d0d8d425b1..895baecd8c42b 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1948,27 +1948,6 @@ OpFoldResult ReinterpretCastOp::fold(FoldAdaptor /*operands*/) {
if (auto prev = src.getDefiningOp<CastOp>())
return prev.getSource();
-  // reinterpret_cast(extract_strided_metadata(x)) -> reinterpret_cast(x).
-  //
-  // We can always fold the source of an extract_strided_metadata op into
-  // the source of a reinterpret_cast op, because both point to the same
-  // memory. Note that the reinterpret_cast does not use the layout of its
-  // input memref, only its base pointer, which is the same as the base
-  // pointer returned by extract_strided_metadata and the base pointer of
-  // the memref fed into extract_strided_metadata. The fold is only
-  // profitable when the reinterpret_cast layout is constant, because then
-  // the extract_strided_metadata is removed by dead code elimination.
-  // With a non-constant layout the extract_strided_metadata survives, and
-  // one of the LLVM tests regresses in performance because the fold gets
-  // in the way of another optimization. For this reason the fold is only
-  // performed for constant layouts.
- if (auto prev = src.getDefiningOp<ExtractStridedMetadataOp>()) {
- if (isLayoutConstant()) {
- return prev.getSource();
- }
- }
-
// reinterpret_cast(subview(x)) -> reinterpret_cast(x) if subview offsets
// are 0.
if (auto prev = src.getDefiningOp<SubViewOp>())
@@ -1994,22 +1973,6 @@ OpFoldResult ReinterpretCastOp::fold(FoldAdaptor /*operands*/) {
return nullptr;
}
-bool ReinterpretCastOp::isLayoutConstant() {
-  auto isConstant = [](OpFoldResult val) {
-    return getConstantIntValue(val).has_value();
-  };
-  return llvm::all_of(getOffsets(), isConstant) &&
-         llvm::all_of(getStrides(), isConstant) &&
-         llvm::all_of(getSizes(), isConstant);
-}
-
SmallVector<OpFoldResult> ReinterpretCastOp::getConstifiedMixedSizes() {
SmallVector<OpFoldResult> values = getMixedSizes();
constifyIndexValues(values, getType(), getContext(), getConstantSizes,
@@ -2071,6 +2034,11 @@ namespace {
/// ```
/// Because we know that `offset` and `c0` will hold 0
/// and `c4` will hold 4.
+///
+/// If the pattern above does not match, the source of the
+/// extract_strided_metadata is still folded into the source of the
+/// reinterpret_cast. This allows dead code elimination to remove the
+/// extract_strided_metadata in some cases.
struct ReinterpretCastOpExtractStridedMetadataFolder
: public OpRewritePattern<ReinterpretCastOp> {
public:
@@ -2082,44 +2050,65 @@ struct ReinterpretCastOpExtractStridedMetadataFolder
op.getSource().getDefiningOp<ExtractStridedMetadataOp>();
if (!extractStridedMetadata)
return failure();
+
// Check if the reinterpret cast reconstructs a memref with the exact same
// properties as the extract strided metadata.
-
- // First, check that the strides are the same.
SmallVector<OpFoldResult> extractStridesOfr =
extractStridedMetadata.getConstifiedMixedStrides();
SmallVector<OpFoldResult> reinterpretStridesOfr =
op.getConstifiedMixedStrides();
- if (extractStridesOfr.size() != reinterpretStridesOfr.size())
- return failure();
+ auto isReinterpretCastNoop = [&]() -> bool {
+ // First, check that the strides are the same.
+ if (extractStridesOfr.size() != reinterpretStridesOfr.size())
+ return false;
- unsigned rank = op.getType().getRank();
- for (unsigned i = 0; i < rank; ++i) {
- if (extractStridesOfr[i] != reinterpretStridesOfr[i])
- return failure();
- }
+ unsigned rank = op.getType().getRank();
+ for (unsigned i = 0; i < rank; ++i) {
+ if (extractStridesOfr[i] != reinterpretStridesOfr[i])
+ return false;
+ }
- // Second, check the sizes.
- assert(extractStridedMetadata.getSizes().size() ==
- op.getMixedSizes().size() &&
- "Strides and sizes rank must match");
- SmallVector<OpFoldResult> extractSizesOfr =
- extractStridedMetadata.getConstifiedMixedSizes();
- SmallVector<OpFoldResult> reinterpretSizesOfr =
- op.getConstifiedMixedSizes();
- for (unsigned i = 0; i < rank; ++i) {
- if (extractSizesOfr[i] != reinterpretSizesOfr[i])
- return failure();
+ // Second, check the sizes.
+ assert(extractStridedMetadata.getSizes().size() ==
+ op.getMixedSizes().size() &&
+ "Strides and sizes rank must match");
+ SmallVector<OpFoldResult> extractSizesOfr =
+ extractStridedMetadata.getConstifiedMixedSizes();
+ SmallVector<OpFoldResult> reinterpretSizesOfr =
+ op.getConstifiedMixedSizes();
+ for (unsigned i = 0; i < rank; ++i) {
+ if (extractSizesOfr[i] != reinterpretSizesOfr[i])
+ return false;
+ }
+ // Finally, check the offset.
+ assert(op.getMixedOffsets().size() == 1 &&
+ "reinterpret_cast with more than one offset should have been "
+ "rejected by the verifier");
+ OpFoldResult extractOffsetOfr =
+ extractStridedMetadata.getConstifiedMixedOffset();
+ OpFoldResult reinterpretOffsetOfr = op.getConstifiedMixedOffset();
+ return extractOffsetOfr == reinterpretOffsetOfr;
+ };
+
+ if (!isReinterpretCastNoop()) {
+    // If the extract_strided_metadata / reinterpret_cast pair can't be
+    // completely folded away, we can still fold the source of the
+    // extract_strided_metadata into the source of the reinterpret_cast.
+    // In some cases (e.g., static dimensions) the extract_strided_metadata
+    // is then eliminated by dead code elimination.
+    //
+    // reinterpret_cast(extract_strided_metadata(x)) -> reinterpret_cast(x).
+    //
+    // This fold is always valid because both ops point to the same memory.
+    // Note that the reinterpret_cast does not use the layout of its input
+    // memref, only its base pointer, which is the same as the base pointer
+    // returned by extract_strided_metadata and the base pointer of the
+    // memref fed into extract_strided_metadata.
+ op.setOperand(0, extractStridedMetadata.getSource());
+ return success();
}
- // Finally, check the offset.
- assert(op.getMixedOffsets().size() == 1 &&
- "reinterpret_cast with more than one offset should have been "
- "rejected by the verifier");
- OpFoldResult extractOffsetOfr =
- extractStridedMetadata.getConstifiedMixedOffset();
- OpFoldResult reinterpretOffsetOfr = op.getConstifiedMixedOffset();
- if (extractOffsetOfr != reinterpretOffsetOfr)
- return failure();
// At this point, we know that the back and forth between extract strided
// metadata and reinterpret cast is a noop. However, the final type of the
diff --git a/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
index f1cb9c9f165be..fe91d26d5a251 100644
--- a/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
@@ -195,10 +195,10 @@ func.func @subview_const_stride(%0 : memref<64x4xf32, strided<[4, 1], offset: 0>
func.func @subview_const_stride_and_offset(%0 : memref<64x8xf32, strided<[8, 1], offset: 0>>) -> memref<62x3xf32, strided<[8, 1], offset: 2>> {
// The last "insertvalue" that populates the memref descriptor from the function arguments.
// CHECK: %[[MEMREF:.*]] = builtin.unrealized_conversion_cast %[[MEM]]
- // CHECK: %[[DESC:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[BASE:.*]] = llvm.extractvalue %[[MEMREF]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[BASE_ALIGNED:.*]] = llvm.extractvalue %[[MEMREF]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+ // CHECK: %[[DESC:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[DESC0:.*]] = llvm.insertvalue %[[BASE]], %[[DESC]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[DESC1:.*]] = llvm.insertvalue %[[BASE_ALIGNED]], %[[DESC0]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[CST_OFF:.*]] = llvm.mlir.constant(2 : index) : i64
@@ -265,11 +265,11 @@ func.func @subview_mixed_static_dynamic(%0 : memref<64x4xf32, strided<[4, 1], of
// CHECK: %[[MEM:.*]]: memref<{{.*}}>,
func.func @subview_leading_operands(%0 : memref<5x3xf32>, %1: memref<5x?xf32>) -> memref<3x3xf32, strided<[3, 1], offset: 6>> {
// CHECK: %[[MEMREF:.*]] = builtin.unrealized_conversion_cast %[[MEM]]
- // CHECK: %[[DESC:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// Alloc ptr
// CHECK: %[[BASE:.*]] = llvm.extractvalue %[[MEMREF]][0] : !llvm.struct<(ptr, ptr, i64
// Aligned ptr
// CHECK: %[[BASE_ALIGNED:.*]] = llvm.extractvalue %[[MEMREF]][1] : !llvm.struct<(ptr, ptr, i64
+ // CHECK: %[[DESC:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[DESC0:.*]] = llvm.insertvalue %[[BASE]], %[[DESC]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[DESC1:.*]] = llvm.insertvalue %[[BASE_ALIGNED]], %[[DESC0]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// Offset
@@ -331,9 +331,9 @@ func.func @subview_leading_operands_dynamic(%0 : memref<5x?xf32>) -> memref<3x?x
// CHECK: %[[MEM:.*]]: memref
func.func @subview_rank_reducing_leading_operands(%0 : memref<5x3xf32>) -> memref<3xf32, strided<[1], offset: 3>> {
// CHECK: %[[MEMREF:.*]] = builtin.unrealized_conversion_cast %[[MEM]]
- // CHECK: %[[DESC:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: %[[BASE:.*]] = llvm.extractvalue %[[MEMREF]][0] : !llvm.struct<(ptr, ptr, i64
// CHECK: %[[BASE_ALIGNED:.*]] = llvm.extractvalue %[[MEMREF]][1] : !llvm.struct<(ptr, ptr, i64
+ // CHECK: %[[DESC:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// Alloc ptr
// CHECK: %[[DESC0:.*]] = llvm.insertvalue %[[BASE]], %[[DESC]][0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// Aligned ptr
@@ -356,9 +356,9 @@ func.func @subview_rank_reducing_leading_operands(%0 : memref<5x3xf32>) -> memre
// CHECK-SAME: (%[[MEM:.*]]: memref<7xf32>)
func.func @subview_negative_stride(%arg0 : memref<7xf32>) -> memref<7xf32, strided<[-1], offset: 6>> {
// CHECK: %[[MEMREF:.*]] = builtin.unrealized_conversion_cast %[[MEM]]
- // CHECK: %[[DESC:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: %[[BASE:.*]] = llvm.extractvalue %[[MEMREF]][0] : !llvm.struct<(ptr, ptr, i64
// CHECK: %[[BASE_ALIGNED:.*]] = llvm.extractvalue %[[MEMREF]][1] : !llvm.struct<(ptr, ptr, i64
+ // CHECK: %[[DESC:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: %[[DESC0:.*]] = llvm.insertvalue %[[BASE]], %[[DESC]][0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: %[[DESC1:.*]] = llvm.insertvalue %[[BASE_ALIGNED]], %[[DESC0]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: %[[CST_OFF0:.*]] = llvm.mlir.constant(6 : index) : i64
@@ -384,12 +384,12 @@ func.func @collapse_shape_static(%arg0: memref<1x3x4x1x5xf32>) -> memref<3x4x5xf
// CHECK-LABEL: func @collapse_shape_static
// CHECK-SAME: %[[ARG:.*]]: memref<1x3x4x1x5xf32>) -> memref<3x4x5xf32> {
// CHECK: %[[MEM:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : memref<1x3x4x1x5xf32> to !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)>
-// CHECK: %[[DESC:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
// CHECK: %[[BASE_BUFFER:.*]] = llvm.extractvalue %[[MEM]][0] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)>
// CHECK: %[[ALIGNED_BUFFER:.*]] = llvm.extractvalue %[[MEM]][1] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)>
+// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64
+// CHECK: %[[DESC:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
// CHECK: %[[DESC0:.*]] = llvm.insertvalue %[[BASE_BUFFER]], %[[DESC]][0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
// CHECK: %[[DESC1:.*]] = llvm.insertvalue %[[ALIGNED_BUFFER]], %[[DESC0]][1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64
// CHECK: %[[DESC2:.*]] = llvm.insertvalue %[[C0]], %[[DESC1]][2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
// CHECK: %[[C3:.*]] = llvm.mlir.constant(3 : index) : i64
// CHECK: %[[DESC3:.*]] = llvm.insertvalue %[[C3]], %[[DESC2]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
@@ -458,12 +458,12 @@ func.func @expand_shape_static(%arg0: memref<3x4x5xf32>) -> memref<1x3x4x1x5xf32
// CHECK-LABEL: func @expand_shape_static
// CHECK-SAME: %[[ARG:.*]]: memref<3x4x5xf32>) -> memref<1x3x4x1x5xf32> {
// CHECK: %[[MEM:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : memref<3x4x5xf32> to !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK: %[[DESC:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)>
// CHECK: %[[BASE_BUFFER:.*]] = llvm.extractvalue %[[MEM]][0] : !llvm.struct<(ptr, ptr, i64,
// CHECK: %[[ALIGNED_BUFFER:.*]] = llvm.extractvalue %[[MEM]][1] : !llvm.struct<(ptr, ptr, i64,
+// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64
+// CHECK: %[[DESC:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)>
// CHECK: %[[DESC0:.*]] = llvm.insertvalue %[[BASE_BUFFER]], %[[DESC]][0] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)>
// CHECK: %[[DESC1:.*]] = llvm.insertvalue %[[ALIGNED_BUFFER]], %[[DESC0]][1] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)>
-// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64
// CHECK: %[[DESC2:.*]] = llvm.insertvalue %[[C0]], %[[DESC1]][2] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)>
// CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : index) : i64
// CHECK: %[[DESC3:.*]] = llvm.insertvalue %[[C1]], %[[DESC2]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)>
@@ -494,9 +494,9 @@ func.func @collapse_shape_fold_zero_dim(%arg0 : memref<1x1xf32>) -> memref<f32>
// CHECK-LABEL: func.func @collapse_shape_fold_zero_dim(
// CHECK-SAME: %[[ARG:.*]]: memref<1x1xf32>) -> memref<f32> {
// CHECK: %[[MEM:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : memref<1x1xf32> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK: %[[DESC:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64)>
// CHECK: %[[BASE_BUFFER:.*]] = llvm.extractvalue %[[MEM]][0] : !llvm.struct<(ptr, ptr, i64,
// CHECK: %[[ALIGNED_BUFFER:.*]] = llvm.extractvalue %[[MEM]][1] : !llvm.struct<(ptr, ptr, i64,
+// CHECK: %[[DESC:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64)>
// CHECK: %[[DESC0:.*]] = llvm.insertvalue %[[BASE_BUFFER]], %[[DESC]][0] : !llvm.struct<(ptr, ptr, i64)>
// CHECK: %[[DESC1:.*]] = llvm.insertvalue %[[ALIGNED_BUFFER]], %[[DESC0]][1] : !llvm.struct<(ptr, ptr, i64)>
// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64
@@ -515,12 +515,12 @@ func.func @expand_shape_zero_dim(%arg0 : memref<f32>) -> memref<1x1xf32> {
// CHECK-LABEL: func.func @expand_shape_zero_dim(
// CHECK-SAME: %[[ARG:.*]]: memref<f32>) -> memref<1x1xf32> {
// CHECK: %[[MEM:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : memref<f32> to !llvm.struct<(ptr, ptr, i64)>
-// CHECK: %[[DESC:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[BASE_BUFFER:.*]] = llvm.extractvalue %[[MEM]][0] : !llvm.struct<(ptr, ptr, i64)>
// CHECK: %[[ALIGNED_BUFFER:.*]] = llvm.extractvalue %[[MEM]][1] : !llvm.struct<(ptr, ptr, i64)>
+// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64
+// CHECK: %[[DESC:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[DESC0:.*]] = llvm.insertvalue %[[BASE_BUFFER]], %[[DESC]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[DESC1:.*]] = llvm.insertvalue %[[ALIGNED_BUFFER]], %[[DESC0]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64
// CHECK: %[[DESC2:.*]] = llvm.insertvalue %[[C0]], %[[DESC1]][2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : index) : i64
// CHECK: %[[DESC3:.*]] = llvm.insertvalue %[[C1]], %[[DESC2]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 5d8a7d3f64e8f..e7cee7cd85426 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -952,8 +952,7 @@ func.func @reinterpret_of_extract_strided_metadata_same_type(%arg0 : memref<?x?x
// CHECK-SAME: (%[[ARG:.*]]: memref<8x2xf32>)
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]]
-// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[C0]]], sizes: [4, 2, 2], strides: [1, 1, %[[C1]]]
+// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [%[[C0]]], sizes: [4, 2, 2], strides: [1, 1, %[[C1]]]
// CHECK: return %[[RES]]
func.func @reinterpret_of_extract_strided_metadata_w_different_stride(%arg0 : memref<8x2xf32>) -> memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>> {
%base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref<f32>, index, index, index, index, index
@@ -969,8 +968,7 @@ func.func @reinterpret_of_extract_strided_metadata_w_different_stride(%arg0 : me
// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]]
-// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [1], sizes: [%[[C8]], %[[C2]]], strides: [%[[C2]], %[[C1]]]
+// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [1], sizes: [%[[C8]], %[[C2]]], strides: [%[[C2]], %[[C1]]]
// CHECK: return %[[RES]]
func.func @reinterpret_of_extract_strided_metadata_w_different_offset(%arg0 : memref<8x2xf32>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> {
%base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref<f32>, index, index, index, index, index
diff --git a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
index da74c73ccd7a5..1e6b0111fa4c7 100644
--- a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
+++ b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
@@ -975,7 +975,11 @@ func.func @simplify_collapse(%arg : memref<?x?x4x?x6x7xi32>)
//
// CHECK-LABEL: func @simplify_collapse_with_dim_of_size1(
// CHECK-SAME: %[[ARG:.*]]: memref<3x1xf32, strided<[2, 1]>>,
-// CHECK: %[[COLLAPSE_VIEW:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [0], sizes: [3], strides: [2]
+//
+// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]] : memref<3x1xf32, strided<[2, 1]>>
+//
+//
+// CHECK: %[[COLLAPSE_VIEW:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [0], sizes: [3], strides: [2]
func.func @simplify_collapse_with_dim_of_size1(%arg0: memref<3x1xf32, strided<[2,1]>>, %arg1: memref<3xf32>) {
%collapse_shape = memref.collapse_shape %arg0 [[0, 1]] :