[Mlir-commits] [mlir] [mlir][MemRef] Add ExtractStridedMetadataOpCollapseShapeFolder (PR #89954)

Diego Caballero llvmlistbot at llvm.org
Fri Apr 26 03:00:31 PDT 2024


https://github.com/dcaballe updated https://github.com/llvm/llvm-project/pull/89954

>From aaf3a95d0e6c52f7ced2d054ca35adb52b15a099 Mon Sep 17 00:00:00 2001
From: Diego Caballero <diegocaballero at google.com>
Date: Wed, 24 Apr 2024 17:07:46 +0000
Subject: [PATCH 1/3] [mlir][MemRef] Add
 ExtractStridedMetadataOpCollapseShapeFolder

This PR adds a new pattern to the set of patterns used to resolve the
offset, sizes and strides of a memref. Similar to `ExtractStridedMetadataOpSubviewFolder`,
the new pattern resolves extract_strided_metadata(collapse_shape) directly,
without introducing a reinterpret_cast op.
---
 .../Transforms/ExpandStridedMetadata.cpp      | 189 ++++++++++++------
 1 file changed, 130 insertions(+), 59 deletions(-)

diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
index 96eb7cfd2db690..b5578a58468e9c 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
@@ -550,6 +550,78 @@ getCollapsedStride(memref::CollapseShapeOp collapseShape, OpBuilder &builder,
   return {makeComposedFoldedAffineMin(builder, collapseShape.getLoc(), minMap,
                                       groupStrides)};
 }
+
+template <typename ReassociativeReshapeLikeOp,
+          SmallVector<OpFoldResult> (*getReshapedSizes)(
+              ReassociativeReshapeLikeOp, OpBuilder &,
+              ArrayRef<OpFoldResult> /*origSizes*/, unsigned /*groupId*/),
+          SmallVector<OpFoldResult> (*getReshapedStrides)(
+              ReassociativeReshapeLikeOp, OpBuilder &,
+              ArrayRef<OpFoldResult> /*origSizes*/,
+              ArrayRef<OpFoldResult> /*origStrides*/, unsigned /*groupId*/)>
+static FailureOr<StridedMetadata>
+resolveReshapeStridedMetadata(RewriterBase &rewriter,
+                              ReassociativeReshapeLikeOp reshape) {
+  // Build a plain extract_strided_metadata(memref) from
+  // extract_strided_metadata(reassociative_reshape_like(memref)).
+  Location origLoc = reshape.getLoc();
+  Value source = reshape.getSrc();
+  auto sourceType = cast<MemRefType>(source.getType());
+  unsigned sourceRank = sourceType.getRank();
+
+  auto newExtractStridedMetadata =
+      rewriter.create<memref::ExtractStridedMetadataOp>(origLoc, source);
+
+  // Collect statically known information.
+  auto [strides, offset] = getStridesAndOffset(sourceType);
+  MemRefType reshapeType = reshape.getResultType();
+  unsigned reshapeRank = reshapeType.getRank();
+
+  OpFoldResult offsetOfr =
+      ShapedType::isDynamic(offset)
+          ? getAsOpFoldResult(newExtractStridedMetadata.getOffset())
+          : rewriter.getIndexAttr(offset);
+
+  // Get the special case of 0-D out of the way.
+  if (sourceRank == 0) {
+    SmallVector<OpFoldResult> ones(reshapeRank, rewriter.getIndexAttr(1));
+    return StridedMetadata{newExtractStridedMetadata.getBaseBuffer(), offsetOfr,
+                           /*sizes=*/ones, /*strides=*/ones};
+  }
+
+  SmallVector<OpFoldResult> finalSizes;
+  finalSizes.reserve(reshapeRank);
+  SmallVector<OpFoldResult> finalStrides;
+  finalStrides.reserve(reshapeRank);
+
+  // Compute the reshaped strides and sizes from the base strides and sizes.
+  SmallVector<OpFoldResult> origSizes =
+      getAsOpFoldResult(newExtractStridedMetadata.getSizes());
+  SmallVector<OpFoldResult> origStrides =
+      getAsOpFoldResult(newExtractStridedMetadata.getStrides());
+  unsigned idx = 0, endIdx = reshape.getReassociationIndices().size();
+  for (; idx != endIdx; ++idx) {
+    SmallVector<OpFoldResult> reshapedSizes =
+        getReshapedSizes(reshape, rewriter, origSizes, /*groupId=*/idx);
+    SmallVector<OpFoldResult> reshapedStrides = getReshapedStrides(
+        reshape, rewriter, origSizes, origStrides, /*groupId=*/idx);
+
+    unsigned groupSize = reshapedSizes.size();
+    for (unsigned i = 0; i < groupSize; ++i) {
+      finalSizes.push_back(reshapedSizes[i]);
+      finalStrides.push_back(reshapedStrides[i]);
+    }
+  }
+  assert(((isa<memref::ExpandShapeOp>(reshape) && idx == sourceRank) ||
+          (isa<memref::CollapseShapeOp>(reshape) && idx == reshapeRank)) &&
+         "We should have visited all the input dimensions");
+  assert(finalSizes.size() == reshapeRank &&
+         "We should have populated all the values");
+
+  return StridedMetadata{newExtractStridedMetadata.getBaseBuffer(), offsetOfr,
+                         finalSizes, finalStrides};
+}
+
 /// Replace `baseBuffer, offset, sizes, strides =
 ///              extract_strided_metadata(reshapeLike(memref))`
 /// With
@@ -580,68 +652,66 @@ struct ReshapeFolder : public OpRewritePattern<ReassociativeReshapeLikeOp> {

   LogicalResult matchAndRewrite(ReassociativeReshapeLikeOp reshape,
                                 PatternRewriter &rewriter) const override {
-    // Build a plain extract_strided_metadata(memref) from
-    // extract_strided_metadata(reassociative_reshape_like(memref)).
-    Location origLoc = reshape.getLoc();
-    Value source = reshape.getSrc();
-    auto sourceType = cast<MemRefType>(source.getType());
-    unsigned sourceRank = sourceType.getRank();
-
-    auto newExtractStridedMetadata =
-        rewriter.create<memref::ExtractStridedMetadataOp>(origLoc, source);
-
-    // Collect statically known information.
-    auto [strides, offset] = getStridesAndOffset(sourceType);
-    MemRefType reshapeType = reshape.getResultType();
-    unsigned reshapeRank = reshapeType.getRank();
-
-    OpFoldResult offsetOfr =
-        ShapedType::isDynamic(offset)
-            ? getAsOpFoldResult(newExtractStridedMetadata.getOffset())
-            : rewriter.getIndexAttr(offset);
-
-    // Get the special case of 0-D out of the way.
-    if (sourceRank == 0) {
-      SmallVector<OpFoldResult> ones(reshapeRank, rewriter.getIndexAttr(1));
-      auto memrefDesc = rewriter.create<memref::ReinterpretCastOp>(
-          origLoc, reshapeType, newExtractStridedMetadata.getBaseBuffer(),
-          offsetOfr, /*sizes=*/ones, /*strides=*/ones);
-      rewriter.replaceOp(reshape, memrefDesc.getResult());
-      return success();
+    FailureOr<StridedMetadata> stridedMetadata =
+        resolveReshapeStridedMetadata<ReassociativeReshapeLikeOp,
+                                      getReshapedSizes, getReshapedStrides>(
+            rewriter, reshape);
+    if (failed(stridedMetadata)) {
+      return rewriter.notifyMatchFailure(reshape,
+                                         "failed to resolve reshape metadata");
     }

-    SmallVector<OpFoldResult> finalSizes;
-    finalSizes.reserve(reshapeRank);
-    SmallVector<OpFoldResult> finalStrides;
-    finalStrides.reserve(reshapeRank);
-
-    // Compute the reshaped strides and sizes from the base strides and sizes.
-    SmallVector<OpFoldResult> origSizes =
-        getAsOpFoldResult(newExtractStridedMetadata.getSizes());
-    SmallVector<OpFoldResult> origStrides =
-        getAsOpFoldResult(newExtractStridedMetadata.getStrides());
-    unsigned idx = 0, endIdx = reshape.getReassociationIndices().size();
-    for (; idx != endIdx; ++idx) {
-      SmallVector<OpFoldResult> reshapedSizes =
-          getReshapedSizes(reshape, rewriter, origSizes, /*groupId=*/idx);
-      SmallVector<OpFoldResult> reshapedStrides = getReshapedStrides(
-          reshape, rewriter, origSizes, origStrides, /*groupId=*/idx);
-
-      unsigned groupSize = reshapedSizes.size();
-      for (unsigned i = 0; i < groupSize; ++i) {
-        finalSizes.push_back(reshapedSizes[i]);
-        finalStrides.push_back(reshapedStrides[i]);
-      }
+    rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
+        reshape, reshape.getType(), stridedMetadata->basePtr,
+        stridedMetadata->offset, stridedMetadata->sizes,
+        stridedMetadata->strides);
+    return success();
+  }
+};
+
+/// Pattern to replace `extract_strided_metadata(collapse_shape)`
+/// With
+///
+/// \verbatim
+/// baseBuffer, baseOffset, baseSizes, baseStrides =
+///     extract_strided_metadata(memref)
+/// sizes#g = getCollapsedSize(collapse_shape, baseSizes, g)
+/// strides#g = getCollapsedStride(collapse_shape, baseSizes, baseStrides, g)
+/// offset = baseOffset
+/// \endverbatim
+///
+/// with `baseBuffer`, `offset`, `sizes` and `strides` being
+/// the replacements for the original `extract_strided_metadata`.
+struct ExtractStridedMetadataOpCollapseShapeFolder
+    : OpRewritePattern<memref::ExtractStridedMetadataOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
+                                PatternRewriter &rewriter) const override {
+    auto collapseShapeOp =
+        op.getSource().getDefiningOp<memref::CollapseShapeOp>();
+    if (!collapseShapeOp)
+      return failure();
+
+    FailureOr<StridedMetadata> stridedMetadata =
+        resolveReshapeStridedMetadata<memref::CollapseShapeOp, getCollapsedSize,
+                                      getCollapsedStride>(rewriter,
+                                                          collapseShapeOp);
+    if (failed(stridedMetadata)) {
+      return rewriter.notifyMatchFailure(
+          op, "failed to resolve metadata in terms of source collapse_shape op");
     }
-    assert(((isa<memref::ExpandShapeOp>(reshape) && idx == sourceRank) ||
-            (isa<memref::CollapseShapeOp>(reshape) && idx == reshapeRank)) &&
-           "We should have visited all the input dimensions");
-    assert(finalSizes.size() == reshapeRank &&
-           "We should have populated all the values");
-    auto memrefDesc = rewriter.create<memref::ReinterpretCastOp>(
-        origLoc, reshapeType, newExtractStridedMetadata.getBaseBuffer(),
-        offsetOfr, finalSizes, finalStrides);
-    rewriter.replaceOp(reshape, memrefDesc.getResult());
+
+    Location loc = collapseShapeOp.getLoc();
+    SmallVector<Value> results;
+    results.push_back(stridedMetadata->basePtr);
+    results.push_back(getValueOrCreateConstantIndexOp(rewriter, loc,
+                                                      stridedMetadata->offset));
+    results.append(
+        getValueOrCreateConstantIndexOp(rewriter, loc, stridedMetadata->sizes));
+    results.append(getValueOrCreateConstantIndexOp(rewriter, loc,
+                                                   stridedMetadata->strides));
+    rewriter.replaceOp(op, results);
     return success();
   }
 };
@@ -1030,6 +1100,7 @@ void memref::populateResolveExtractStridedMetadataPatterns(
     RewritePatternSet &patterns) {
   patterns.add<ExtractStridedMetadataOpAllocFolder<memref::AllocOp>,
                ExtractStridedMetadataOpAllocFolder<memref::AllocaOp>,
+               ExtractStridedMetadataOpCollapseShapeFolder,
                ExtractStridedMetadataOpGetGlobalFolder,
                ExtractStridedMetadataOpSubviewFolder,
                RewriteExtractAlignedPointerAsIndexOfViewLikeOp,

>From 19c46b519f9a64e92ad9638d04e2a0ba528b96bf Mon Sep 17 00:00:00 2001
From: Diego Caballero <diegocaballero at google.com>
Date: Fri, 26 Apr 2024 09:55:40 +0000
Subject: [PATCH 2/3] Review feedback

- Add test
- Add doc
- Use function_ref
---
 .../Transforms/ExpandStridedMetadata.cpp      | 49 ++++++++++++-------
 .../MemRef/expand-strided-metadata.mlir       | 24 +++++++++++-
 2 files changed, 54 insertions(+), 19 deletions(-)

diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
index b5578a58468e9c..479646756cb5df 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
@@ -551,17 +551,29 @@ getCollapsedStride(memref::CollapseShapeOp collapseShape, OpBuilder &builder,
                                        groupStrides)};
 }

-template <typename ReassociativeReshapeLikeOp,
-          SmallVector<OpFoldResult> (*getReshapedSizes)(
-              ReassociativeReshapeLikeOp, OpBuilder &,
-              ArrayRef<OpFoldResult> /*origSizes*/, unsigned /*groupId*/),
-          SmallVector<OpFoldResult> (*getReshapedStrides)(
-              ReassociativeReshapeLikeOp, OpBuilder &,
-              ArrayRef<OpFoldResult> /*origSizes*/,
-              ArrayRef<OpFoldResult> /*origStrides*/, unsigned /*groupId*/)>
-static FailureOr<StridedMetadata>
-resolveReshapeStridedMetadata(RewriterBase &rewriter,
-                              ReassociativeReshapeLikeOp reshape) {
+/// From `reshape_like(memref, subSizes, subStrides))` compute
+///
+/// \verbatim
+/// baseBuffer, baseOffset, baseSizes, baseStrides =
+///     extract_strided_metadata(memref)
+/// strides#i = baseStrides#i * subStrides#i
+/// offset = baseOffset + sum(subOffset#i * baseStrides#i)
+/// sizes = subSizes
+/// \endverbatim
+///
+/// and return {baseBuffer, offset, sizes, strides}
+template <typename ReassociativeReshapeLikeOp>
+static FailureOr<StridedMetadata> resolveReshapeStridedMetadata(
+    RewriterBase &rewriter, ReassociativeReshapeLikeOp reshape,
+    function_ref<SmallVector<OpFoldResult>(
+        ReassociativeReshapeLikeOp, OpBuilder &,
+        ArrayRef<OpFoldResult> /*origSizes*/, unsigned /*groupId*/)>
+        getReshapedSizes,
+    function_ref<SmallVector<OpFoldResult>(
+        ReassociativeReshapeLikeOp, OpBuilder &,
+        ArrayRef<OpFoldResult> /*origSizes*/,
+        ArrayRef<OpFoldResult> /*origStrides*/, unsigned /*groupId*/)>
+        getReshapedStrides) {
   // Build a plain extract_strided_metadata(memref) from
   // extract_strided_metadata(reassociative_reshape_like(memref)).
   Location origLoc = reshape.getLoc();
@@ -653,9 +665,8 @@ struct ReshapeFolder : public OpRewritePattern<ReassociativeReshapeLikeOp> {
   LogicalResult matchAndRewrite(ReassociativeReshapeLikeOp reshape,
                                 PatternRewriter &rewriter) const override {
     FailureOr<StridedMetadata> stridedMetadata =
-        resolveReshapeStridedMetadata<ReassociativeReshapeLikeOp,
-                                      getReshapedSizes, getReshapedStrides>(
-            rewriter, reshape);
+        resolveReshapeStridedMetadata<ReassociativeReshapeLikeOp>(
+            rewriter, reshape, getReshapedSizes, getReshapedStrides);
     if (failed(stridedMetadata)) {
       return rewriter.notifyMatchFailure(reshape,
                                          "failed to resolve reshape metadata");
@@ -694,12 +705,12 @@ struct ExtractStridedMetadataOpCollapseShapeFolder
       return failure();

     FailureOr<StridedMetadata> stridedMetadata =
-        resolveReshapeStridedMetadata<memref::CollapseShapeOp, getCollapsedSize,
-                                      getCollapsedStride>(rewriter,
-                                                          collapseShapeOp);
+        resolveReshapeStridedMetadata<memref::CollapseShapeOp>(
+            rewriter, collapseShapeOp, getCollapsedSize, getCollapsedStride);
     if (failed(stridedMetadata)) {
       return rewriter.notifyMatchFailure(
-          op, "failed to resolve metadata in terms of source collapse_shape op");
+          op,
+          "failed to resolve metadata in terms of source collapse_shape op");
     }

     Location loc = collapseShapeOp.getLoc();
@@ -1088,9 +1099,11 @@ void memref::populateExpandStridedMetadataPatterns(
                              getCollapsedStride>,
                ExtractStridedMetadataOpAllocFolder<memref::AllocOp>,
                ExtractStridedMetadataOpAllocFolder<memref::AllocaOp>,
+               ExtractStridedMetadataOpCollapseShapeFolder,
                ExtractStridedMetadataOpGetGlobalFolder,
                RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
                ExtractStridedMetadataOpReinterpretCastFolder,
+               ExtractStridedMetadataOpSubviewFolder,
                ExtractStridedMetadataOpCastFolder,
                ExtractStridedMetadataOpExtractStridedMetadataFolder>(
       patterns.getContext());
diff --git a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
index 28b70043005940..0705b30ca45d86 100644
--- a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
+++ b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
@@ -1513,4 +1513,26 @@ func.func @zero_sized_memred(%arg0: f32) -> (memref<f16, 3>, index,index,index)
     %sizes, %strides :
       memref<f16,3>, index,
       index, index
-}
\ No newline at end of file
+}
+
+// -----
+
+func.func @extract_strided_metadata_of_collapse_shape(%base: memref<5x4xf32>)
+    -> (memref<f32>, index, index, index) {
+
+  %collapse = memref.collapse_shape %base[[0, 1]] :
+    memref<5x4xf32> into memref<20xf32>
+
+  %base_buffer, %offset, %size, %stride = memref.extract_strided_metadata %collapse :
+    memref<20xf32> -> memref<f32>, index, index, index
+
+  return %base_buffer, %offset, %size, %stride :
+    memref<f32>, index, index, index
+}
+
+// CHECK-LABEL:  func @extract_strided_metadata_of_collapse_shape
+//   CHECK-DAG:    %[[OFFSET:.*]] = arith.constant 0 : index
+//   CHECK-DAG:    %[[SIZE:.*]] = arith.constant 20 : index
+//   CHECK-DAG:    %[[STEP:.*]] = arith.constant 1 : index
+//       CHECK:    %[[BASE:.*]], %{{.*}}, %{{.*}}, %{{.*}} = memref.extract_strided_metadata
+//       CHECK:    return %[[BASE]], %[[OFFSET]], %[[SIZE]], %[[STEP]] : memref<f32>, index, index, index

>From 3500630fc7a90164dd5180e4696fbd901b083545 Mon Sep 17 00:00:00 2001
From: Diego Caballero <diegocaballero at google.com>
Date: Fri, 26 Apr 2024 10:00:15 +0000
Subject: [PATCH 3/3] Fix doc

---
 mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
index 479646756cb5df..999b50e25ca8f1 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
@@ -551,17 +551,18 @@ getCollapsedStride(memref::CollapseShapeOp collapseShape, OpBuilder &builder,
                                        groupStrides)};
 }

-/// From `reshape_like(memref, subSizes, subStrides))` compute
+/// From `reassociative_reshape_like(memref)` compute
 ///
 /// \verbatim
 /// baseBuffer, baseOffset, baseSizes, baseStrides =
 ///     extract_strided_metadata(memref)
-/// strides#i = baseStrides#i * subStrides#i
-/// offset = baseOffset + sum(subOffset#i * baseStrides#i)
-/// sizes = subSizes
+/// sizes#g = getReshapedSizes(reshape, baseSizes, g)
+/// strides#g = getReshapedStrides(reshape, baseSizes, baseStrides, g)
+/// sizes = concat(sizes#g), strides = concat(strides#g) over all groups g
 /// \endverbatim
 ///
-/// and return {baseBuffer, offset, sizes, strides}
+/// and return {baseBuffer, baseOffset, sizes, strides}
+/// (sizes and strides are all ones when the source memref is 0-D).
 template <typename ReassociativeReshapeLikeOp>
 static FailureOr<StridedMetadata> resolveReshapeStridedMetadata(
     RewriterBase &rewriter, ReassociativeReshapeLikeOp reshape,



More information about the Mlir-commits mailing list