[Mlir-commits] [mlir] 4178366 - [mlir][MemRef] Change the anchor point of a subview pattern
Quentin Colombet
llvmlistbot at llvm.org
Mon Nov 14 10:54:46 PST 2022
Author: Quentin Colombet
Date: 2022-11-14T18:43:34Z
New Revision: 41783666e4ae958e85db8e7ef04bea4ab909ab9e
URL: https://github.com/llvm/llvm-project/commit/41783666e4ae958e85db8e7ef04bea4ab909ab9e
DIFF: https://github.com/llvm/llvm-project/commit/41783666e4ae958e85db8e7ef04bea4ab909ab9e.diff
LOG: [mlir][MemRef] Change the anchor point of a subview pattern
Essentially, this patch changes the anchor point of the
`extract_strided_metadata(subview)` pattern from
`extract_strided_metadata` to `subview`.
In detail, this means that instead of replacing:
```
base, offset, sizes, strides = extract_strided_metadata(subview(src))
```
With
```
base, ... = extract_strided_metadata(src)
offset = <some math>
sizes = subSizes
strides = <some math>
```
We replace only the subview part and connect it back with a
reinterpret_cast:
```
val = subview(src)
```
=>
```
base, ... = extract_strided_metadata(src)
offset = <some math>
sizes = subSizes
strides = <some math>
val = reinterpret_cast base, offset, sizes, strides
```
Differential Revision: https://reviews.llvm.org/D135839
Added:
Modified:
mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
index 5a95e7ee668d6..3414538b71e33 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
@@ -30,9 +30,7 @@ namespace memref {
using namespace mlir;
namespace {
-/// Replace `baseBuffer, offset, sizes, strides =
-/// extract_strided_metadata(subview(memref, subOffset,
-/// subSizes, subStrides))`
+/// Replace `dst = subview(memref, subOffset, subSizes, subStrides)`
/// With
///
/// \verbatim
@@ -41,24 +39,19 @@ namespace {
/// strides#i = baseStrides#i * subSizes#i
/// offset = baseOffset + sum(subOffset#i * baseStrides#i)
/// sizes = subSizes
+/// dst = reinterpret_cast baseBuffer, offset, sizes, strides
/// \endverbatim
///
/// In other words, get rid of the subview in that expression and canonicalize
/// on its effects on the offset, the sizes, and the strides using affine.apply.
-struct ExtractStridedMetadataOpSubviewFolder
- : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
+struct SubviewFolder : public OpRewritePattern<memref::SubViewOp> {
public:
- using OpRewritePattern<memref::ExtractStridedMetadataOp>::OpRewritePattern;
+ using OpRewritePattern<memref::SubViewOp>::OpRewritePattern;
- LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
+ LogicalResult matchAndRewrite(memref::SubViewOp subview,
PatternRewriter &rewriter) const override {
- auto subview = op.getSource().getDefiningOp<memref::SubViewOp>();
- if (!subview)
- return failure();
-
- // Build a plain extract_strided_metadata(memref) from
- // extract_strided_metadata(subview(memref)).
- Location origLoc = op.getLoc();
+ // Build a plain extract_strided_metadata(memref) from subview(memref).
+ Location origLoc = subview.getLoc();
Value source = subview.getSource();
auto sourceType = source.getType().cast<MemRefType>();
unsigned sourceRank = sourceType.getRank();
@@ -117,20 +110,11 @@ struct ExtractStridedMetadataOpSubviewFolder
OpFoldResult finalOffset =
makeComposedFoldedAffineApply(rewriter, origLoc, expr, values);
- SmallVector<Value> results;
// The final result is <baseBuffer, offset, sizes, strides>.
// Thus we need 1 + 1 + subview.getRank() + subview.getRank(), to hold all
// the values.
auto subType = subview.getType().cast<MemRefType>();
unsigned subRank = subType.getRank();
- // Properly size the array so that we can do random insertions
- // at the right indices.
- // We do that to populate the non-dropped sizes and strides in one go.
- results.resize_for_overwrite(subRank * 2 + 2);
-
- results[0] = newExtractStridedMetadata.getBaseBuffer();
- results[1] =
- getValueOrCreateConstantIndexOp(rewriter, origLoc, finalOffset);
// The sizes of the final type are defined directly by the input sizes of
// the subview.
@@ -139,24 +123,30 @@ struct ExtractStridedMetadataOpSubviewFolder
// replacing.
// Do the filtering here.
SmallVector<OpFoldResult> subSizes = subview.getMixedSizes();
- const unsigned sizeStartIdx = 2;
- const unsigned strideStartIdx = sizeStartIdx + subRank;
- unsigned insertedDims = 0;
llvm::SmallBitVector droppedDims = subview.getDroppedDims();
+
+ SmallVector<OpFoldResult> finalSizes;
+ finalSizes.reserve(subRank);
+
+ SmallVector<OpFoldResult> finalStrides;
+ finalStrides.reserve(subRank);
+
for (unsigned i = 0; i < sourceRank; ++i) {
if (droppedDims.test(i))
continue;
- results[sizeStartIdx + insertedDims] =
- getValueOrCreateConstantIndexOp(rewriter, origLoc, subSizes[i]);
- results[strideStartIdx + insertedDims] =
- getValueOrCreateConstantIndexOp(rewriter, origLoc, strides[i]);
- ++insertedDims;
+ finalSizes.push_back(subSizes[i]);
+ finalStrides.push_back(strides[i]);
}
- assert(insertedDims == subRank &&
+ assert(finalSizes.size() == subRank &&
"Should have populated all the values at this point");
- rewriter.replaceOp(op, results);
+ auto memrefDesc = rewriter.create<memref::ReinterpretCastOp>(
+ origLoc, subType, newExtractStridedMetadata.getBaseBuffer(),
+ finalOffset,
+ /*sizes=*/finalSizes,
+ /*strides=*/finalStrides);
+ rewriter.replaceOp(subview, memrefDesc.getResult());
return success();
}
};
@@ -756,7 +746,7 @@ class ExtractStridedMetadataOpExtractStridedMetadataFolder
void memref::populateSimplifyExtractStridedMetadataOpPatterns(
RewritePatternSet &patterns) {
patterns
- .add<ExtractStridedMetadataOpSubviewFolder,
+ .add<SubviewFolder,
ExtractStridedMetadataOpReshapeFolder<
memref::ExpandShapeOp, getExpandedSizes, getExpandedStrides>,
ExtractStridedMetadataOpReshapeFolder<
diff --git a/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir b/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir
index b6661ee1b5dd5..3f312df10e214 100644
--- a/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir
+++ b/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir
@@ -21,6 +21,57 @@ func.func @extract_strided_metadata_constants(%base: memref<5x4xf32, strided<[4,
// -----
+// Check that we simplify subview(src) into:
+// base, offset, sizes, strides = extract_strided_metadata src
+// final_sizes = subSizes
+// final_strides = <some math> strides
+// final_offset = <some math> offset
+// reinterpret_cast base to final_offset, final_sizes, final_strides
+//
+// Orig strides: [s0, s1, s2]
+// Sub strides: [subS0, subS1, subS2]
+// => New strides: [s0 * subS0, s1 * subS1, s2 * subS2]
+// ==> 1 affine map (used for each stride) with two values.
+//
+// Orig offset: origOff
+// Sub offsets: [subO0, subO1, subO2]
+// => Final offset: s0 * subO0 + ... + s2 * subO2 + origOff
+// ==> 1 affine map with (rank * 2 + 1) symbols
+//
+// CHECK-DAG: #[[$STRIDE_MAP:.*]] = affine_map<()[s0, s1] -> (s0 * s1)>
+// CHECK-DAG: #[[$OFFSET_MAP:.*]] = affine_map<()[s0, s1, s2, s3, s4, s5, s6] -> (s0 + s1 * s2 + s3 * s4 + s5 * s6)>
+// CHECK-LABEL: func @simplify_subview_all_dynamic
+// CHECK-SAME: (%[[ARG:.*]]: memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>, %[[DYN_OFFSET0:.*]]: index, %[[DYN_OFFSET1:.*]]: index, %[[DYN_OFFSET2:.*]]: index, %[[DYN_SIZE0:.*]]: index, %[[DYN_SIZE1:.*]]: index, %[[DYN_SIZE2:.*]]: index, %[[DYN_STRIDE0:.*]]: index, %[[DYN_STRIDE1:.*]]: index, %[[DYN_STRIDE2:.*]]: index)
+//
+// CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:3, %[[STRIDES:.*]]:3 = memref.extract_strided_metadata %[[ARG]]
+//
+// CHECK-DAG: %[[FINAL_STRIDE0:.*]] = affine.apply #[[$STRIDE_MAP]]()[%[[DYN_STRIDE0]], %[[STRIDES]]#0]
+// CHECK-DAG: %[[FINAL_STRIDE1:.*]] = affine.apply #[[$STRIDE_MAP]]()[%[[DYN_STRIDE1]], %[[STRIDES]]#1]
+// CHECK-DAG: %[[FINAL_STRIDE2:.*]] = affine.apply #[[$STRIDE_MAP]]()[%[[DYN_STRIDE2]], %[[STRIDES]]#2]
+//
+// CHECK-DAG: %[[FINAL_OFFSET:.*]] = affine.apply #[[$OFFSET_MAP]]()[%[[OFFSET]], %[[DYN_OFFSET0]], %[[STRIDES]]#0, %[[DYN_OFFSET1]], %[[STRIDES]]#1, %[[DYN_OFFSET2]], %[[STRIDES]]#2]
+//
+// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[FINAL_OFFSET]]], sizes: [%[[DYN_SIZE0]], %[[DYN_SIZE1]], %[[DYN_SIZE2]]], strides: [%[[FINAL_STRIDE0]], %[[FINAL_STRIDE1]], %[[FINAL_STRIDE2]]]
+//
+// CHECK: return %[[RES]]
+func.func @simplify_subview_all_dynamic(
+ %base: memref<?x?x?xf32, strided<[?,?,?], offset:?>>,
+ %offset0: index, %offset1: index, %offset2: index,
+ %size0: index, %size1: index, %size2: index,
+ %stride0: index, %stride1: index, %stride2: index)
+ -> memref<?x?x?xf32, strided<[?,?,?], offset:?>> {
+
+ %subview = memref.subview %base[%offset0, %offset1, %offset2]
+ [%size0, %size1, %size2]
+ [%stride0, %stride1, %stride2] :
+ memref<?x?x?xf32, strided<[?,?,?], offset: ?>> to
+ memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
+
+ return %subview : memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
+}
+
+// -----
+
// Check that we simplify extract_strided_metadata of subview to
// base_buf, base_offset, base_sizes, base_strides = extract_strided_metadata
// strides = base_stride_i * subview_stride_i
More information about the Mlir-commits
mailing list