[Mlir-commits] [mlir] 4178366 - [mlir][MemRef] Change the anchor point of a subview pattern

Quentin Colombet llvmlistbot at llvm.org
Mon Nov 14 10:54:46 PST 2022


Author: Quentin Colombet
Date: 2022-11-14T18:43:34Z
New Revision: 41783666e4ae958e85db8e7ef04bea4ab909ab9e

URL: https://github.com/llvm/llvm-project/commit/41783666e4ae958e85db8e7ef04bea4ab909ab9e
DIFF: https://github.com/llvm/llvm-project/commit/41783666e4ae958e85db8e7ef04bea4ab909ab9e.diff

LOG: [mlir][MemRef] Change the anchor point of a subview pattern

Essentially, this patch changes the anchor point of the
`extract_strided_metadata(subview)` pattern from
`extract_strided_metadata` to `subview`.

In detail, this means that instead of replacing:
```
base, offset, sizes, strides = extract_strided_metadata(subview(src))
```
With
```
base, ... = extract_strided_metadata(src)
offset = <some math>
sizes = subSizes
strides = <some math>
```

We replace only the subview part and connect it back with a
reinterpret_cast:
```
val = subview(src)
```
=>
```
base, ... = extract_strided_metadata(src)
offset = <some math>
sizes = subSizes
strides = <some math>
val = reinterpret_cast base, offset, sizes, strides
```

Differential Revision: https://reviews.llvm.org/D135839

Added: 
    

Modified: 
    mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
    mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
index 5a95e7ee668d6..3414538b71e33 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
@@ -30,9 +30,7 @@ namespace memref {
 using namespace mlir;
 
 namespace {
-/// Replace `baseBuffer, offset, sizes, strides =
-///              extract_strided_metadata(subview(memref, subOffset,
-///                                               subSizes, subStrides))`
+/// Replace `dst = subview(memref, subOffset, subSizes, subStrides)`
 /// With
 ///
 /// \verbatim
@@ -41,24 +39,19 @@ namespace {
 /// strides#i = baseStrides#i * subSizes#i
 /// offset = baseOffset + sum(subOffset#i * baseStrides#i)
 /// sizes = subSizes
+/// dst = reinterpret_cast baseBuffer, offset, sizes, strides
 /// \endverbatim
 ///
 /// In other words, get rid of the subview in that expression and canonicalize
 /// on its effects on the offset, the sizes, and the strides using affine.apply.
-struct ExtractStridedMetadataOpSubviewFolder
-    : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
+struct SubviewFolder : public OpRewritePattern<memref::SubViewOp> {
 public:
-  using OpRewritePattern<memref::ExtractStridedMetadataOp>::OpRewritePattern;
+  using OpRewritePattern<memref::SubViewOp>::OpRewritePattern;
 
-  LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
+  LogicalResult matchAndRewrite(memref::SubViewOp subview,
                                 PatternRewriter &rewriter) const override {
-    auto subview = op.getSource().getDefiningOp<memref::SubViewOp>();
-    if (!subview)
-      return failure();
-
-    // Build a plain extract_strided_metadata(memref) from
-    // extract_strided_metadata(subview(memref)).
-    Location origLoc = op.getLoc();
+    // Build a plain extract_strided_metadata(memref) from subview(memref).
+    Location origLoc = subview.getLoc();
     Value source = subview.getSource();
     auto sourceType = source.getType().cast<MemRefType>();
     unsigned sourceRank = sourceType.getRank();
@@ -117,20 +110,11 @@ struct ExtractStridedMetadataOpSubviewFolder
     OpFoldResult finalOffset =
         makeComposedFoldedAffineApply(rewriter, origLoc, expr, values);
 
-    SmallVector<Value> results;
     // The final result is  <baseBuffer, offset, sizes, strides>.
     // Thus we need 1 + 1 + subview.getRank() + subview.getRank(), to hold all
     // the values.
     auto subType = subview.getType().cast<MemRefType>();
     unsigned subRank = subType.getRank();
-    // Properly size the array so that we can do random insertions
-    // at the right indices.
-    // We do that to populate the non-dropped sizes and strides in one go.
-    results.resize_for_overwrite(subRank * 2 + 2);
-
-    results[0] = newExtractStridedMetadata.getBaseBuffer();
-    results[1] =
-        getValueOrCreateConstantIndexOp(rewriter, origLoc, finalOffset);
 
     // The sizes of the final type are defined directly by the input sizes of
     // the subview.
@@ -139,24 +123,30 @@ struct ExtractStridedMetadataOpSubviewFolder
     // replacing.
     // Do the filtering here.
     SmallVector<OpFoldResult> subSizes = subview.getMixedSizes();
-    const unsigned sizeStartIdx = 2;
-    const unsigned strideStartIdx = sizeStartIdx + subRank;
-    unsigned insertedDims = 0;
     llvm::SmallBitVector droppedDims = subview.getDroppedDims();
+
+    SmallVector<OpFoldResult> finalSizes;
+    finalSizes.reserve(subRank);
+
+    SmallVector<OpFoldResult> finalStrides;
+    finalStrides.reserve(subRank);
+
     for (unsigned i = 0; i < sourceRank; ++i) {
       if (droppedDims.test(i))
         continue;
 
-      results[sizeStartIdx + insertedDims] =
-          getValueOrCreateConstantIndexOp(rewriter, origLoc, subSizes[i]);
-      results[strideStartIdx + insertedDims] =
-          getValueOrCreateConstantIndexOp(rewriter, origLoc, strides[i]);
-      ++insertedDims;
+      finalSizes.push_back(subSizes[i]);
+      finalStrides.push_back(strides[i]);
     }
-    assert(insertedDims == subRank &&
+    assert(finalSizes.size() == subRank &&
            "Should have populated all the values at this point");
 
-    rewriter.replaceOp(op, results);
+    auto memrefDesc = rewriter.create<memref::ReinterpretCastOp>(
+        origLoc, subType, newExtractStridedMetadata.getBaseBuffer(),
+        finalOffset,
+        /*sizes=*/finalSizes,
+        /*strides=*/finalStrides);
+    rewriter.replaceOp(subview, memrefDesc.getResult());
     return success();
   }
 };
@@ -756,7 +746,7 @@ class ExtractStridedMetadataOpExtractStridedMetadataFolder
 void memref::populateSimplifyExtractStridedMetadataOpPatterns(
     RewritePatternSet &patterns) {
   patterns
-      .add<ExtractStridedMetadataOpSubviewFolder,
+      .add<SubviewFolder,
            ExtractStridedMetadataOpReshapeFolder<
                memref::ExpandShapeOp, getExpandedSizes, getExpandedStrides>,
            ExtractStridedMetadataOpReshapeFolder<

diff  --git a/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir b/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir
index b6661ee1b5dd5..3f312df10e214 100644
--- a/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir
+++ b/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir
@@ -21,6 +21,57 @@ func.func @extract_strided_metadata_constants(%base: memref<5x4xf32, strided<[4,
 
 // -----
 
+// Check that we simplify subview(src) into:
+// base, offset, sizes, strides = extract_strided_metadata src
+// final_sizes = subSizes
+// final_strides = <some math> strides
+// final_offset = <some math> offset
+// reinterpret_cast base to final_offset, final_sizes, final_strides
+//
+// Orig strides: [s0, s1, s2]
+// Sub strides: [subS0, subS1, subS2]
+// => New strides: [s0 * subS0, s1 * subS1, s2 * subS2]
+// ==> 1 affine map (used for each stride) with two values.
+//
+// Orig offset: origOff
+// Sub offsets: [subO0, subO1, subO2]
+// => Final offset: s0 * subO0 + ... + s2 * subO2 + origOff
+// ==> 1 affine map with (rank * 2 + 1) symbols
+//
+// CHECK-DAG: #[[$STRIDE_MAP:.*]] = affine_map<()[s0, s1] -> (s0 * s1)>
+// CHECK-DAG: #[[$OFFSET_MAP:.*]] = affine_map<()[s0, s1, s2, s3, s4, s5, s6] -> (s0 + s1 * s2 + s3 * s4 + s5 * s6)>
+// CHECK-LABEL: func @simplify_subview_all_dynamic
+//  CHECK-SAME: (%[[ARG:.*]]: memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>, %[[DYN_OFFSET0:.*]]: index, %[[DYN_OFFSET1:.*]]: index, %[[DYN_OFFSET2:.*]]: index, %[[DYN_SIZE0:.*]]: index, %[[DYN_SIZE1:.*]]: index, %[[DYN_SIZE2:.*]]: index, %[[DYN_STRIDE0:.*]]: index, %[[DYN_STRIDE1:.*]]: index, %[[DYN_STRIDE2:.*]]: index)
+//
+//   CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:3, %[[STRIDES:.*]]:3 = memref.extract_strided_metadata %[[ARG]]
+//
+//  CHECK-DAG: %[[FINAL_STRIDE0:.*]] = affine.apply #[[$STRIDE_MAP]]()[%[[DYN_STRIDE0]], %[[STRIDES]]#0]
+//  CHECK-DAG: %[[FINAL_STRIDE1:.*]] = affine.apply #[[$STRIDE_MAP]]()[%[[DYN_STRIDE1]], %[[STRIDES]]#1]
+//  CHECK-DAG: %[[FINAL_STRIDE2:.*]] = affine.apply #[[$STRIDE_MAP]]()[%[[DYN_STRIDE2]], %[[STRIDES]]#2]
+//
+//  CHECK-DAG: %[[FINAL_OFFSET:.*]] = affine.apply #[[$OFFSET_MAP]]()[%[[OFFSET]], %[[DYN_OFFSET0]], %[[STRIDES]]#0, %[[DYN_OFFSET1]], %[[STRIDES]]#1, %[[DYN_OFFSET2]], %[[STRIDES]]#2]
+//
+//      CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[FINAL_OFFSET]]], sizes: [%[[DYN_SIZE0]], %[[DYN_SIZE1]], %[[DYN_SIZE2]]], strides: [%[[FINAL_STRIDE0]], %[[FINAL_STRIDE1]], %[[FINAL_STRIDE2]]]
+//
+//       CHECK: return %[[RES]]
+func.func @simplify_subview_all_dynamic(
+    %base: memref<?x?x?xf32, strided<[?,?,?], offset:?>>,
+    %offset0: index, %offset1: index, %offset2: index,
+    %size0: index, %size1: index, %size2: index,
+    %stride0: index, %stride1: index, %stride2: index)
+    -> memref<?x?x?xf32, strided<[?,?,?], offset:?>> {
+  // Every subview offset, size, and stride is a function argument (fully dynamic).
+  %subview = memref.subview %base[%offset0, %offset1, %offset2]
+                                 [%size0, %size1, %size2]
+                                 [%stride0, %stride1, %stride2] :
+    memref<?x?x?xf32, strided<[?,?,?], offset: ?>> to
+      memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
+
+  return %subview : memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
+}
+
+// -----
+
 // Check that we simplify extract_strided_metadata of subview to
 // base_buf, base_offset, base_sizes, base_strides = extract_strided_metadata
 // strides = base_stride_i * subview_stride_i


        


More information about the Mlir-commits mailing list