[Mlir-commits] [mlir] [mlir] remove folder on MemRef::ExtractStridedMetadataOp (PR #88043)

Nirvedh Meshram llvmlistbot at llvm.org
Mon Apr 8 14:52:37 PDT 2024


https://github.com/nirvedhmeshram updated https://github.com/llvm/llvm-project/pull/88043

>From 38382affec8278369d835c560d4b3bad7fafd7ea Mon Sep 17 00:00:00 2001
From: Nirvedh Meshram <nirvedh at gmail.com>
Date: Mon, 8 Apr 2024 14:59:51 -0600
Subject: [PATCH 1/2] [mlir] remove folder on MemRef::ExtractStridedMetadataOp
 that is causing downstream bugs

---
 .../mlir/Dialect/MemRef/IR/MemRefOps.td       |  2 -
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp      | 47 -------------
 mlir/test/Dialect/MemRef/canonicalize.mlir    | 66 -------------------
 .../MemRef/expand-strided-metadata.mlir       | 30 ++-------
 4 files changed, 4 insertions(+), 141 deletions(-)

diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 39e66cd9e6e5ab..b9b7963736c9b9 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -987,8 +987,6 @@ def MemRef_ExtractStridedMetadataOp : MemRef_Op<"extract_strided_metadata", [
 
     ::mlir::Value getViewSource() { return getSource(); }
   }];
-
-  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 836dcb8f329e70..9729f98088ce02 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1374,53 +1374,6 @@ void ExtractStridedMetadataOp::getAsmResultNames(
   }
 }
 
-/// Helper function to perform the replacement of all constant uses of `values`
-/// by a materialized constant extracted from `maybeConstants`.
-/// `values` and `maybeConstants` are expected to have the same size.
-template <typename Container>
-static bool replaceConstantUsesOf(OpBuilder &rewriter, Location loc,
-                                  Container values,
-                                  ArrayRef<OpFoldResult> maybeConstants) {
-  assert(values.size() == maybeConstants.size() &&
-         " expected values and maybeConstants of the same size");
-  bool atLeastOneReplacement = false;
-  for (auto [maybeConstant, result] : llvm::zip(maybeConstants, values)) {
-    // Don't materialize a constant if there are no uses: this would indice
-    // infinite loops in the driver.
-    if (result.use_empty() || maybeConstant == getAsOpFoldResult(result))
-      continue;
-    assert(maybeConstant.template is<Attribute>() &&
-           "The constified value should be either unchanged (i.e., == result) "
-           "or a constant");
-    Value constantVal = rewriter.create<arith::ConstantIndexOp>(
-        loc, llvm::cast<IntegerAttr>(maybeConstant.template get<Attribute>())
-                 .getInt());
-    for (Operation *op : llvm::make_early_inc_range(result.getUsers())) {
-      // modifyOpInPlace: lambda cannot capture structured bindings in C++17
-      // yet.
-      op->replaceUsesOfWith(result, constantVal);
-      atLeastOneReplacement = true;
-    }
-  }
-  return atLeastOneReplacement;
-}
-
-LogicalResult
-ExtractStridedMetadataOp::fold(FoldAdaptor adaptor,
-                               SmallVectorImpl<OpFoldResult> &results) {
-  OpBuilder builder(*this);
-
-  bool atLeastOneReplacement = replaceConstantUsesOf(
-      builder, getLoc(), ArrayRef<TypedValue<IndexType>>(getOffset()),
-      getConstifiedMixedOffset());
-  atLeastOneReplacement |= replaceConstantUsesOf(builder, getLoc(), getSizes(),
-                                                 getConstifiedMixedSizes());
-  atLeastOneReplacement |= replaceConstantUsesOf(
-      builder, getLoc(), getStrides(), getConstifiedMixedStrides());
-
-  return success(atLeastOneReplacement);
-}
-
 SmallVector<OpFoldResult> ExtractStridedMetadataOp::getConstifiedMixedSizes() {
   SmallVector<OpFoldResult> values = getAsOpFoldResult(getSizes());
   constifyIndexValues(values, getSource().getType(), getContext(),
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 506ed1f1c10b10..0105c976661105 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -870,38 +870,6 @@ func.func @reinterpret_of_subview(%arg : memref<?xi8>, %size1: index, %size2: in
 
 // -----
 
-// Check that a reinterpret cast of an equivalent extract strided metadata
-// is canonicalized to a plain cast when the destination type is different
-// than the type of the original memref.
-// CHECK-LABEL: func @reinterpret_of_extract_strided_metadata_w_type_mistach
-//  CHECK-SAME: (%[[ARG:.*]]: memref<8x2xf32>)
-//       CHECK: %[[CAST:.*]] = memref.cast %[[ARG]] : memref<8x2xf32> to memref<?x?xf32,
-//       CHECK: return %[[CAST]]
-func.func @reinterpret_of_extract_strided_metadata_w_type_mistach(%arg0 : memref<8x2xf32>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> {
-  %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref<f32>, index, index, index, index, index
-  %m2 = memref.reinterpret_cast %base to offset: [%offset], sizes: [%sizes#0, %sizes#1], strides: [%strides#0, %strides#1] : memref<f32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
-  return %m2 : memref<?x?xf32, strided<[?, ?], offset: ?>>
-}
-
-// -----
-
-// Similar to reinterpret_of_extract_strided_metadata_w_type_mistach except that
-// we check that the match happen when the static information has been folded.
-// E.g., in this case, we know that size of dim 0 is 8 and size of dim 1 is 2.
-// So even if we don't use the values sizes#0, sizes#1, as long as they have the
-// same constant value, the match is valid.
-// CHECK-LABEL: func @reinterpret_of_extract_strided_metadata_w_constants
-//  CHECK-SAME: (%[[ARG:.*]]: memref<8x2xf32>)
-//       CHECK: %[[CAST:.*]] = memref.cast %[[ARG]] : memref<8x2xf32> to memref<?x?xf32,
-//       CHECK: return %[[CAST]]
-func.func @reinterpret_of_extract_strided_metadata_w_constants(%arg0 : memref<8x2xf32>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> {
-  %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref<f32>, index, index, index, index, index
-  %c8 = arith.constant 8: index
-  %m2 = memref.reinterpret_cast %base to offset: [0], sizes: [%c8, 2], strides: [2, %strides#1] : memref<f32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
-  return %m2 : memref<?x?xf32, strided<[?, ?], offset: ?>>
-}
-// -----
-
 // Check that a reinterpret cast of an equivalent extract strided metadata
 // is completely removed when the original memref has the same type.
 // CHECK-LABEL: func @reinterpret_of_extract_strided_metadata_same_type
@@ -915,40 +883,6 @@ func.func @reinterpret_of_extract_strided_metadata_same_type(%arg0 : memref<?x?x
 
 // -----
 
-// Check that we don't simplify reinterpret cast of extract strided metadata
-// when the strides don't match.
-// CHECK-LABEL: func @reinterpret_of_extract_strided_metadata_w_different_stride
-//  CHECK-SAME: (%[[ARG:.*]]: memref<8x2xf32>)
-//   CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-//   CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-//   CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]]
-//       CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[C0]]], sizes: [4, 2, 2], strides: [1, 1, %[[C1]]]
-//       CHECK: return %[[RES]]
-func.func @reinterpret_of_extract_strided_metadata_w_different_stride(%arg0 : memref<8x2xf32>) -> memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>> {
-  %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref<f32>, index, index, index, index, index
-  %m2 = memref.reinterpret_cast %base to offset: [%offset], sizes: [4, 2, 2], strides: [1, 1, %strides#1] : memref<f32> to memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
-  return %m2 : memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
-}
-// -----
-
-// Check that we don't simplify reinterpret cast of extract strided metadata
-// when the offset doesn't match.
-// CHECK-LABEL: func @reinterpret_of_extract_strided_metadata_w_different_offset
-//  CHECK-SAME: (%[[ARG:.*]]: memref<8x2xf32>)
-//   CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
-//   CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
-//   CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-//   CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]]
-//       CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [1], sizes: [%[[C8]], %[[C2]]], strides: [%[[C2]], %[[C1]]]
-//       CHECK: return %[[RES]]
-func.func @reinterpret_of_extract_strided_metadata_w_different_offset(%arg0 : memref<8x2xf32>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> {
-  %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref<f32>, index, index, index, index, index
-  %m2 = memref.reinterpret_cast %base to offset: [1], sizes: [%sizes#0, %sizes#1], strides: [%strides#0, %strides#1] : memref<f32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
-  return %m2 : memref<?x?xf32, strided<[?, ?], offset: ?>>
-}
-
-// -----
-
 func.func @canonicalize_rank_reduced_subview(%arg0 : memref<8x?xf32>,
     %arg1 : index) -> memref<?xf32, strided<[?], offset: ?>> {
   %c0 = arith.constant 0 : index
diff --git a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
index 28b70043005940..70a5310cf9a45c 100644
--- a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
+++ b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
@@ -1,25 +1,5 @@
 // RUN: mlir-opt --expand-strided-metadata -split-input-file %s -o - | FileCheck %s
 
-// CHECK-LABEL: func @extract_strided_metadata_constants
-//  CHECK-SAME: (%[[ARG:.*]]: memref<5x4xf32, strided<[4, 1], offset: 2>>)
-func.func @extract_strided_metadata_constants(%base: memref<5x4xf32, strided<[4, 1], offset: 2>>)
-    -> (memref<f32>, index, index, index, index, index) {
-  //   CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-  //   CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
-  //   CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
-  //   CHECK-DAG: %[[C5:.*]] = arith.constant 5 : index
-
-  //       CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]]
-  %base_buffer, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %base :
-    memref<5x4xf32, strided<[4,1], offset:2>>
-    -> memref<f32>, index, index, index, index, index
-
-  // CHECK: %[[BASE]], %[[C2]], %[[C5]], %[[C4]], %[[C4]], %[[C1]]
-  return %base_buffer, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 :
-    memref<f32>, index, index, index, index, index
-}
-
-// -----
 
 // Check that we simplify subview(src) into:
 // base, offset, sizes, strides xtract_strided_metadata src
@@ -1070,6 +1050,7 @@ func.func @simplify_collapse_with_dim_of_size1_and_resulting_dyn_stride
 //
 //   CHECK-DAG: #[[$SIZE0_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 * s1) * 4)>
 //   CHECK-DAG: #[[$STRIDE0_MAP:.*]] = affine_map<()[s0] -> (s0)>
+//   CHECK-DAG: #[[$STRIDE1_MAP:.*]] = affine_map<()[s0, s1] -> (s0, s1, 42)>
 // CHECK-LABEL: func @extract_strided_metadata_of_collapse(
 //  CHECK-SAME: %[[ARG:.*]]: memref<?x?x4x?x6x7xi32>)
 //
@@ -1081,8 +1062,8 @@ func.func @simplify_collapse_with_dim_of_size1_and_resulting_dyn_stride
 //
 //   CHECK-DAG: %[[DYN_STRIDE0:.*]] = affine.min #[[$STRIDE0_MAP]]()[%[[STRIDES]]#0]
 //   CHECK-DAG: %[[DYN_SIZE1:.*]] = affine.apply #[[$SIZE0_MAP]]()[%[[SIZES]]#1, %[[SIZES]]#3]
-//
-//       CHECK: return %[[BASE]], %[[C0]], %[[SIZES]]#0, %[[DYN_SIZE1]], %[[C42]], %[[DYN_STRIDE0]], %[[C42]], %[[C1]]
+//   CHECK-DAG:  %[[DYN_STRIDE1:.*]]  = affine.min #[[$STRIDE1_MAP]]()[%strides#1, %strides#2]
+//       CHECK: return %[[BASE]], %[[C0]], %[[SIZES]]#0, %[[DYN_SIZE1]], %[[C42]], %[[DYN_STRIDE0]], %[[DYN_STRIDE1]], %[[C1]]
 func.func @extract_strided_metadata_of_collapse(%arg : memref<?x?x4x?x6x7xi32>)
   -> (memref<i32>, index,
       index, index, index,
@@ -1383,12 +1364,9 @@ func.func @extract_strided_metadata_of_get_global_with_offset()
 //
 // CHECK-LABEL: func @extract_strided_metadata_of_cast
 //  CHECK-SAME: %[[ARG:.*]]: memref<3x?xi32, strided<[4, ?], offset: ?>>)
-//
-//   CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
-//   CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
 //       CHECK: %[[BASE:.*]], %[[DYN_OFFSET:.*]], %[[DYN_SIZES:.*]]:2, %[[DYN_STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]]
 //
-//       CHECK: return %[[BASE]], %[[DYN_OFFSET]], %[[C3]], %[[DYN_SIZES]]#1, %[[C4]], %[[DYN_STRIDES]]#1
+//       CHECK: return %[[BASE]], %[[DYN_OFFSET]], %[[DYN_SIZES]]#0, %[[DYN_SIZES]]#1, %[[DYN_STRIDES]]#0, %[[DYN_STRIDES]]#1
 func.func @extract_strided_metadata_of_cast(
   %arg : memref<3x?xi32, strided<[4, ?], offset:?>>)
   -> (memref<i32>, index,

>From 704207c9e0fd358fdfe6078edb670c111452c93d Mon Sep 17 00:00:00 2001
From: Nirvedh Meshram <nirvedh at gmail.com>
Date: Mon, 8 Apr 2024 15:52:24 -0600
Subject: [PATCH 2/2] Address reviewer comments and add back the tests

---
 mlir/test/Dialect/MemRef/canonicalize.mlir    | 30 +++++++++++++++++++
 .../MemRef/expand-strided-metadata.mlir       | 21 ++++++++++++-
 2 files changed, 50 insertions(+), 1 deletion(-)

diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 0105c976661105..9211c2495f1a46 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -883,6 +883,36 @@ func.func @reinterpret_of_extract_strided_metadata_same_type(%arg0 : memref<?x?x
 
 // -----
 
+// Check that we don't simplify reinterpret cast of extract strided metadata
+// when the strides don't match.
+// CHECK-LABEL: func @reinterpret_of_extract_strided_metadata_w_different_stride
+//  CHECK-SAME: (%[[ARG:.*]]: memref<8x2xf32>)
+//   CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]]
+//       CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[OFFSET]]], sizes: [4, 2, 2], strides: [1, 1, %[[STRIDES]]#1]
+//       CHECK: return %[[RES]]
+func.func @reinterpret_of_extract_strided_metadata_w_different_stride(%arg0 : memref<8x2xf32>) -> memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>> {
+  %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref<f32>, index, index, index, index, index
+  %m2 = memref.reinterpret_cast %base to offset: [%offset], sizes: [4, 2, 2], strides: [1, 1, %strides#1] : memref<f32> to memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
+  return %m2 : memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
+}
+// -----
+
+// Check that we don't simplify reinterpret cast of extract strided metadata
+// when the offset doesn't match.
+// CHECK-LABEL: func @reinterpret_of_extract_strided_metadata_w_different_offset
+//  CHECK-SAME: (%[[ARG:.*]]: memref<8x2xf32>)
+//   CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]]
+//       CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [1], sizes: [%[[SIZES]]#0, %[[SIZES]]#1], strides: [%[[STRIDES]]#0, %[[STRIDES]]#1]
+//       CHECK: return %[[RES]]
+func.func @reinterpret_of_extract_strided_metadata_w_different_offset(%arg0 : memref<8x2xf32>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> {
+  %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref<f32>, index, index, index, index, index
+  %m2 = memref.reinterpret_cast %base to offset: [1], sizes: [%sizes#0, %sizes#1], strides: [%strides#0, %strides#1] : memref<f32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
+  return %m2 : memref<?x?xf32, strided<[?, ?], offset: ?>>
+}
+
+// -----
+
+
 func.func @canonicalize_rank_reduced_subview(%arg0 : memref<8x?xf32>,
     %arg1 : index) -> memref<?xf32, strided<[?], offset: ?>> {
   %c0 = arith.constant 0 : index
diff --git a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
index 70a5310cf9a45c..407443a91edafe 100644
--- a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
+++ b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
@@ -1,6 +1,25 @@
 // RUN: mlir-opt --expand-strided-metadata -split-input-file %s -o - | FileCheck %s
 
 
+// memref.extract_strided_metadata is not folded away because the metadata of the %base memref
+// can change in a later pass, and folding away the op here would cause incorrect lowering.
+// CHECK-LABEL: func @extract_strided_metadata_constants
+//  CHECK-SAME: (%[[ARG:.*]]: memref<5x4xf32, strided<[4, 1], offset: 2>>)
+func.func @extract_strided_metadata_constants(%base: memref<5x4xf32, strided<[4, 1], offset: 2>>)
+    -> (memref<f32>, index, index, index, index, index) {
+
+  //       CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]]
+  %base_buffer, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %base :
+    memref<5x4xf32, strided<[4,1], offset:2>>
+    -> memref<f32>, index, index, index, index, index
+
+  // CHECK: %[[BASE]], %[[OFFSET]], %[[SIZES]]#0, %[[SIZES]]#1, %[[STRIDES]]#0, %[[STRIDES]]#1
+  return %base_buffer, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 :
+    memref<f32>, index, index, index, index, index
+}
+
+// -----
+
 // Check that we simplify subview(src) into:
 // base, offset, sizes, strides xtract_strided_metadata src
 // final_sizes = subSizes
@@ -1062,7 +1081,7 @@ func.func @simplify_collapse_with_dim_of_size1_and_resulting_dyn_stride
 //
 //   CHECK-DAG: %[[DYN_STRIDE0:.*]] = affine.min #[[$STRIDE0_MAP]]()[%[[STRIDES]]#0]
 //   CHECK-DAG: %[[DYN_SIZE1:.*]] = affine.apply #[[$SIZE0_MAP]]()[%[[SIZES]]#1, %[[SIZES]]#3]
-//   CHECK-DAG:  %[[DYN_STRIDE1:.*]]  = affine.min #[[$STRIDE1_MAP]]()[%strides#1, %strides#2]
+//   CHECK-DAG:  %[[DYN_STRIDE1:.*]]  = affine.min #[[$STRIDE1_MAP]]()[%[[STRIDES]]#1, %[[STRIDES]]#2]
 //       CHECK: return %[[BASE]], %[[C0]], %[[SIZES]]#0, %[[DYN_SIZE1]], %[[C42]], %[[DYN_STRIDE0]], %[[DYN_STRIDE1]], %[[C1]]
 func.func @extract_strided_metadata_of_collapse(%arg : memref<?x?x4x?x6x7xi32>)
   -> (memref<i32>, index,



More information about the Mlir-commits mailing list