[Mlir-commits] [mlir] d665448 - [mlir][MemRef] Change the anchor point of a reshapeLikeOp pattern

Quentin Colombet llvmlistbot at llvm.org
Mon Nov 14 11:06:05 PST 2022


Author: Quentin Colombet
Date: 2022-11-14T18:56:35Z
New Revision: d665448a7f3ad88fac6c852703cf1f7baeb9200b

URL: https://github.com/llvm/llvm-project/commit/d665448a7f3ad88fac6c852703cf1f7baeb9200b
DIFF: https://github.com/llvm/llvm-project/commit/d665448a7f3ad88fac6c852703cf1f7baeb9200b.diff

LOG: [mlir][MemRef] Change the anchor point of a reshapeLikeOp pattern

Essentially, this patch changes the anchor point of the
`extract_strided_metadata(reshapeLikeOp)` pattern from
`extract_strided_metadata` to `reshapeLikeOp`.

In detail, this means that instead of replacing:
```
base, offset, sizes, strides =
  extract_strided_metadata(reshapeLikeOp(src))
```
With
```
base, offset = extract_strided_metadata(src)
sizes = <some math>
strides = <some math>
```

We replace only the reshapeLikeOp part and connect it back with a
reinterpret_cast:
```
val = reshapeLikeOp(src)
```
=>
```
base, offset, ... = extract_strided_metadata(src)
sizes = <some math>
strides = <some math>
val = reinterpret_cast base, offset, sizes, strides
```

Differential Revision: https://reviews.llvm.org/D136386

Added: 
    

Modified: 
    mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
    mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
index 3414538b71e33..bbf83575bd6c0 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
@@ -455,20 +455,15 @@ template <typename ReassociativeReshapeLikeOp,
               ReassociativeReshapeLikeOp, OpBuilder &,
               ArrayRef<OpFoldResult> /*origSizes*/,
               ArrayRef<OpFoldResult> /*origStrides*/, unsigned /*groupId*/)>
-struct ExtractStridedMetadataOpReshapeFolder
-    : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
+struct ReshapeFolder : public OpRewritePattern<ReassociativeReshapeLikeOp> {
 public:
-  using OpRewritePattern<memref::ExtractStridedMetadataOp>::OpRewritePattern;
+  using OpRewritePattern<ReassociativeReshapeLikeOp>::OpRewritePattern;
 
-  LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
+  LogicalResult matchAndRewrite(ReassociativeReshapeLikeOp reshape,
                                 PatternRewriter &rewriter) const override {
-    auto reshape = op.getSource().getDefiningOp<ReassociativeReshapeLikeOp>();
-    if (!reshape)
-      return failure();
-
     // Build a plain extract_strided_metadata(memref) from
     // extract_strided_metadata(reassociative_reshape_like(memref)).
-    Location origLoc = op.getLoc();
+    Location origLoc = reshape.getLoc();
     Value source = reshape.getSrc();
     auto sourceType = source.getType().cast<MemRefType>();
     unsigned sourceRank = sourceType.getRank();
@@ -487,27 +482,26 @@ struct ExtractStridedMetadataOpReshapeFolder
     MemRefType reshapeType = reshape.getResultType();
     unsigned reshapeRank = reshapeType.getRank();
 
-    // The result value will start with the base_buffer and offset.
-    unsigned baseIdxInResult = 2;
-    SmallVector<OpFoldResult> results(baseIdxInResult + reshapeRank * 2);
-    results[0] = newExtractStridedMetadata.getBaseBuffer();
-    results[1] = ShapedType::isDynamicStrideOrOffset(offset)
-                     ? getAsOpFoldResult(newExtractStridedMetadata.getOffset())
-                     : rewriter.getIndexAttr(offset);
+    OpFoldResult offsetOfr =
+        ShapedType::isDynamicStrideOrOffset(offset)
+            ? getAsOpFoldResult(newExtractStridedMetadata.getOffset())
+            : rewriter.getIndexAttr(offset);
 
     // Get the special case of 0-D out of the way.
     if (sourceRank == 0) {
-      Value constantOne = getValueOrCreateConstantIndexOp(
-          rewriter, origLoc, rewriter.getIndexAttr(1));
-      SmallVector<Value> resultValues(baseIdxInResult + reshapeRank * 2,
-                                      constantOne);
-      for (unsigned i = 0; i < baseIdxInResult; ++i)
-        resultValues[i] =
-            getValueOrCreateConstantIndexOp(rewriter, origLoc, results[i]);
-      rewriter.replaceOp(op, resultValues);
+      SmallVector<OpFoldResult> ones(reshapeRank, rewriter.getIndexAttr(1));
+      auto memrefDesc = rewriter.create<memref::ReinterpretCastOp>(
+          origLoc, reshapeType, newExtractStridedMetadata.getBaseBuffer(),
+          offsetOfr, /*sizes=*/ones, /*strides=*/ones);
+      rewriter.replaceOp(reshape, memrefDesc.getResult());
       return success();
     }
 
+    SmallVector<OpFoldResult> finalSizes;
+    finalSizes.reserve(reshapeRank);
+    SmallVector<OpFoldResult> finalStrides;
+    finalStrides.reserve(reshapeRank);
+
     // Compute the reshaped strides and sizes from the base strides and sizes.
     SmallVector<OpFoldResult> origSizes =
         getAsOpFoldResult(newExtractStridedMetadata.getSizes());
@@ -521,21 +515,20 @@ struct ExtractStridedMetadataOpReshapeFolder
           reshape, rewriter, origSizes, origStrides, /*groupId=*/idx);
 
       unsigned groupSize = reshapedSizes.size();
-      const unsigned sizeStartIdx = baseIdxInResult;
-      const unsigned strideStartIdx = sizeStartIdx + reshapeRank;
       for (unsigned i = 0; i < groupSize; ++i) {
-        results[sizeStartIdx + i] = reshapedSizes[i];
-        results[strideStartIdx + i] = reshapedStrides[i];
+        finalSizes.push_back(reshapedSizes[i]);
+        finalStrides.push_back(reshapedStrides[i]);
       }
-      baseIdxInResult += groupSize;
     }
     assert(((isa<memref::ExpandShapeOp>(reshape) && idx == sourceRank) ||
             (isa<memref::CollapseShapeOp>(reshape) && idx == reshapeRank)) &&
            "We should have visited all the input dimensions");
-    assert(baseIdxInResult == reshapeRank + 2 &&
+    assert(finalSizes.size() == reshapeRank &&
            "We should have populated all the values");
-    rewriter.replaceOp(
-        op, getValueOrCreateConstantIndexOp(rewriter, origLoc, results));
+    auto memrefDesc = rewriter.create<memref::ReinterpretCastOp>(
+        origLoc, reshapeType, newExtractStridedMetadata.getBaseBuffer(),
+        offsetOfr, finalSizes, finalStrides);
+    rewriter.replaceOp(reshape, memrefDesc.getResult());
     return success();
   }
 };
@@ -745,18 +738,17 @@ class ExtractStridedMetadataOpExtractStridedMetadataFolder
 
 void memref::populateSimplifyExtractStridedMetadataOpPatterns(
     RewritePatternSet &patterns) {
-  patterns
-      .add<SubviewFolder,
-           ExtractStridedMetadataOpReshapeFolder<
-               memref::ExpandShapeOp, getExpandedSizes, getExpandedStrides>,
-           ExtractStridedMetadataOpReshapeFolder<
-               memref::CollapseShapeOp, getCollapsedSize, getCollapsedStride>,
-           ExtractStridedMetadataOpAllocFolder<memref::AllocOp>,
-           ExtractStridedMetadataOpAllocFolder<memref::AllocaOp>,
-           RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
-           ExtractStridedMetadataOpReinterpretCastFolder,
-           ExtractStridedMetadataOpExtractStridedMetadataFolder>(
-          patterns.getContext());
+  patterns.add<SubviewFolder,
+               ReshapeFolder<memref::ExpandShapeOp, getExpandedSizes,
+                             getExpandedStrides>,
+               ReshapeFolder<memref::CollapseShapeOp, getCollapsedSize,
+                             getCollapsedStride>,
+               ExtractStridedMetadataOpAllocFolder<memref::AllocOp>,
+               ExtractStridedMetadataOpAllocFolder<memref::AllocaOp>,
+               RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
+               ExtractStridedMetadataOpReinterpretCastFolder,
+               ExtractStridedMetadataOpExtractStridedMetadataFolder>(
+      patterns.getContext());
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir b/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir
index 3f312df10e214..d7f2dfdb77b31 100644
--- a/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir
+++ b/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir
@@ -352,6 +352,88 @@ func.func @extract_strided_metadata_of_subview_all_dynamic(
 
 // -----
 
+// Check that we properly simplify expand_shape into:
+// reinterpret_cast(extract_strided_metadata) + <some math>
+//
+// Here we have:
+// For the group applying to dim0:
+// size 0 = baseSizes#0 / (all static sizes in that group)
+//        = baseSizes#0 / (7 * 8 * 9)
+//        = baseSizes#0 / 504
+// size 1 = 7
+// size 2 = 8
+// size 3 = 9
+// stride 0 = baseStrides#0 * 7 * 8 * 9
+//          = baseStrides#0 * 504
+// stride 1 = baseStrides#0 * 8 * 9
+//          = baseStrides#0 * 72
+// stride 2 = baseStrides#0 * 9
+// stride 3 = baseStrides#0
+//
+// For the group applying to dim1:
+// size 4 = 10
+// size 5 = 2
+// size 6 = baseSizes#1 / (all static sizes in that group)
+//        = baseSizes#1 / (10 * 2 * 3)
+//        = baseSizes#1 / 60
+// size 7 = 3
+// stride 4 = baseStrides#1 * size 5 * size 6 * size 7
+//          = baseStrides#1 * 2 * (baseSizes#1 / 60) * 3
+//          = baseStrides#1 * (baseSizes#1 / 60) * 6
+//          and since we know that baseSizes#1 is a multiple of 60:
+//          = baseStrides#1 * (baseSizes#1 / 10)
+// stride 5 = baseStrides#1 * size 6 * size 7
+//          = baseStrides#1 * (baseSizes#1 / 60) * 3
+//          = baseStrides#1 * (baseSizes#1 / 20)
+// stride 6 = baseStrides#1 * size 7
+//          = baseStrides#1 * 3
+// stride 7 = baseStrides#1
+//
+// Base and offset are unchanged.
+//
+//   CHECK-DAG: #[[$DIM0_SIZE_MAP:.*]] = affine_map<()[s0] -> (s0 floordiv 504)>
+//   CHECK-DAG: #[[$DIM6_SIZE_MAP:.*]] = affine_map<()[s0] -> (s0 floordiv 60)>
+//
+//   CHECK-DAG: #[[$DIM0_STRIDE_MAP:.*]] = affine_map<()[s0] -> (s0 * 504)>
+//   CHECK-DAG: #[[$DIM1_STRIDE_MAP:.*]] = affine_map<()[s0] -> (s0 * 72)>
+//   CHECK-DAG: #[[$DIM2_STRIDE_MAP:.*]] = affine_map<()[s0] -> (s0 * 9)>
+//   CHECK-DAG: #[[$DIM4_STRIDE_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 floordiv 10) * s1)>
+//   CHECK-DAG: #[[$DIM5_STRIDE_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 floordiv 20) * s1)>
+//   CHECK-DAG: #[[$DIM6_STRIDE_MAP:.*]] = affine_map<()[s0] -> (s0 * 3)>
+// CHECK-LABEL: func @simplify_expand_shape
+//  CHECK-SAME: (%[[ARG:.*]]: memref<?x?xf32,
+//
+//   CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]] : memref<?x?xf32, strided<[?, ?], offset: ?>> -> memref<f32>, index, index, index, index, index
+//
+//   CHECK-DAG: %[[DYN_SIZE0:.*]] = affine.apply #[[$DIM0_SIZE_MAP]]()[%[[SIZES]]#0]
+//   CHECK-DAG: %[[DYN_SIZE6:.*]] = affine.apply #[[$DIM6_SIZE_MAP]]()[%[[SIZES]]#1]
+//   CHECK-DAG: %[[DYN_STRIDE0:.*]] = affine.apply #[[$DIM0_STRIDE_MAP]]()[%[[STRIDES]]#0]
+//   CHECK-DAG: %[[DYN_STRIDE1:.*]] = affine.apply #[[$DIM1_STRIDE_MAP]]()[%[[STRIDES]]#0]
+//   CHECK-DAG: %[[DYN_STRIDE2:.*]] = affine.apply #[[$DIM2_STRIDE_MAP]]()[%[[STRIDES]]#0]
+//   CHECK-DAG: %[[DYN_STRIDE4:.*]] = affine.apply #[[$DIM4_STRIDE_MAP]]()[%[[SIZES]]#1, %[[STRIDES]]#1]
+//   CHECK-DAG: %[[DYN_STRIDE5:.*]] = affine.apply #[[$DIM5_STRIDE_MAP]]()[%[[SIZES]]#1, %[[STRIDES]]#1]
+//   CHECK-DAG: %[[DYN_STRIDE6:.*]] = affine.apply #[[$DIM6_STRIDE_MAP]]()[%[[STRIDES]]#1]
+//
+//   CHECK-DAG: %[[REINTERPRET_CAST:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[OFFSET]]], sizes: [%[[DYN_SIZE0]], 7, 8, 9, 10, 2, %[[DYN_SIZE6]], 3], strides: [%[[DYN_STRIDE0]], %[[DYN_STRIDE1]], %[[DYN_STRIDE2]], %[[STRIDES]]#0, %[[DYN_STRIDE4]], %[[DYN_STRIDE5]], %[[DYN_STRIDE6]], %[[STRIDES]]#1]
+//
+//   CHECK: return %[[REINTERPRET_CAST]]
+func.func @simplify_expand_shape(
+    %base: memref<?x?xf32, strided<[?,?], offset:?>>,
+    %offset0: index, %offset1: index, %offset2: index,
+    %size0: index, %size1: index, %size2: index,
+    %stride0: index, %stride1: index, %stride2: index)
+    -> memref<?x7x8x9x10x2x?x3xf32, strided<[?, ?, ?, ?, ?, ?, ?, ?], offset: ?>> {
+
+  %subview = memref.expand_shape %base[[0, 1, 2, 3],[4, 5, 6, 7]] :
+    memref<?x?xf32, strided<[?,?], offset: ?>> into
+      memref<?x7x8x9x10x2x?x3xf32, strided<[?, ?, ?, ?, ?, ?, ?, ?], offset: ?>>
+
+  return %subview :
+    memref<?x7x8x9x10x2x?x3xf32, strided<[?, ?, ?, ?, ?, ?, ?, ?], offset: ?>>
+}
+
+// -----
+
 // Check that we properly simplify extract_strided_metadata of expand_shape
 // into:
 // baseBuffer, baseOffset, baseSizes, baseStrides =
@@ -815,6 +897,43 @@ func.func @extract_aligned_pointer_as_index(%arg0: memref<f32>) -> index {
 
 // -----
 
+// Check that we simplify collapse_shape into
+// reinterpret_cast(extract_strided_metadata) + <some math>
+//
+// We transform: ?x?x4x?x6x7xi32 to [0][1,2,3][4,5]
+// Size 0 = origSize0
+// Size 1 = origSize1 * origSize2 * origSize3
+//        = origSize1 * 4 * origSize3
+// Size 2 = origSize4 * origSize5
+//        = 6 * 7
+//        = 42
+// Stride 0 = origStride0
+// Stride 1 = origStride3 (orig stride of the inner most dimension)
+//          = 42
+// Stride 2 = origStride5
+//          = 1
+//
+//   CHECK-DAG: #[[$SIZE0_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 * s1) * 4)>
+// CHECK-LABEL: func @simplify_collapse(
+//  CHECK-SAME: %[[ARG:.*]]: memref<?x?x4x?x6x7xi32>)
+//
+//       CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:6, %[[STRIDES:.*]]:6 = memref.extract_strided_metadata %[[ARG]] : memref<?x?x4x?x6x7xi32>
+//
+//       CHECK: %[[DYN_SIZE1:.*]] = affine.apply #[[$SIZE0_MAP]]()[%[[SIZES]]#1, %[[SIZES]]#3]
+//
+//       CHECK: %[[COLLAPSE_VIEW:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [0], sizes: [%[[SIZES]]#0, %[[DYN_SIZE1]], 42], strides: [%[[STRIDES]]#0, 42, 1]
+func.func @simplify_collapse(%arg : memref<?x?x4x?x6x7xi32>)
+  -> memref<?x?x42xi32> {
+
+  %collapsed_view = memref.collapse_shape %arg [[0], [1, 2, 3], [4, 5]] :
+    memref<?x?x4x?x6x7xi32> into memref<?x?x42xi32>
+
+  return %collapsed_view : memref<?x?x42xi32>
+
+}
+
+// -----
+
 // Check that we simplify extract_strided_metadata of collapse_shape.
 //
 // We transform: ?x?x4x?x6x7xi32 to [0][1,2,3][4,5]


        


More information about the Mlir-commits mailing list