[Mlir-commits] [mlir] d665448 - [mlir][MemRef] Change the anchor point of a reshapeLikeOp pattern
Quentin Colombet
llvmlistbot at llvm.org
Mon Nov 14 11:06:05 PST 2022
Author: Quentin Colombet
Date: 2022-11-14T18:56:35Z
New Revision: d665448a7f3ad88fac6c852703cf1f7baeb9200b
URL: https://github.com/llvm/llvm-project/commit/d665448a7f3ad88fac6c852703cf1f7baeb9200b
DIFF: https://github.com/llvm/llvm-project/commit/d665448a7f3ad88fac6c852703cf1f7baeb9200b.diff
LOG: [mlir][MemRef] Change the anchor point of a reshapeLikeOp pattern
Essentially, this patch changes the anchor point of the
`extract_strided_metadata(reshapeLikeOp)` pattern from
`extract_strided_metadata` to `reshapeLikeOp`.
In detail, this means that instead of replacing:
```
base, offset, sizes, strides =
extract_strided_metadata(reshapeLikeOp(src))
```
With
```
base, offset = extract_strided_metadata(src)
sizes = <some math>
strides = <some math>
```
We replace only the reshapeLikeOp part and connect it back with a
reinterpret_cast:
```
val = reshapeLikeOp(src)
```
=>
```
base, offset, ... = extract_strided_metadata(src)
sizes = <some math>
strides = <some math>
val = reinterpret_cast base, offset, sizes, strides
```
Differential Revision: https://reviews.llvm.org/D136386
Added:
Modified:
mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
index 3414538b71e33..bbf83575bd6c0 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
@@ -455,20 +455,15 @@ template <typename ReassociativeReshapeLikeOp,
ReassociativeReshapeLikeOp, OpBuilder &,
ArrayRef<OpFoldResult> /*origSizes*/,
ArrayRef<OpFoldResult> /*origStrides*/, unsigned /*groupId*/)>
-struct ExtractStridedMetadataOpReshapeFolder
- : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
+struct ReshapeFolder : public OpRewritePattern<ReassociativeReshapeLikeOp> {
public:
- using OpRewritePattern<memref::ExtractStridedMetadataOp>::OpRewritePattern;
+ using OpRewritePattern<ReassociativeReshapeLikeOp>::OpRewritePattern;
- LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
+ LogicalResult matchAndRewrite(ReassociativeReshapeLikeOp reshape,
PatternRewriter &rewriter) const override {
- auto reshape = op.getSource().getDefiningOp<ReassociativeReshapeLikeOp>();
- if (!reshape)
- return failure();
-
// Build a plain extract_strided_metadata(memref) from
// extract_strided_metadata(reassociative_reshape_like(memref)).
- Location origLoc = op.getLoc();
+ Location origLoc = reshape.getLoc();
Value source = reshape.getSrc();
auto sourceType = source.getType().cast<MemRefType>();
unsigned sourceRank = sourceType.getRank();
@@ -487,27 +482,26 @@ struct ExtractStridedMetadataOpReshapeFolder
MemRefType reshapeType = reshape.getResultType();
unsigned reshapeRank = reshapeType.getRank();
- // The result value will start with the base_buffer and offset.
- unsigned baseIdxInResult = 2;
- SmallVector<OpFoldResult> results(baseIdxInResult + reshapeRank * 2);
- results[0] = newExtractStridedMetadata.getBaseBuffer();
- results[1] = ShapedType::isDynamicStrideOrOffset(offset)
- ? getAsOpFoldResult(newExtractStridedMetadata.getOffset())
- : rewriter.getIndexAttr(offset);
+ OpFoldResult offsetOfr =
+ ShapedType::isDynamicStrideOrOffset(offset)
+ ? getAsOpFoldResult(newExtractStridedMetadata.getOffset())
+ : rewriter.getIndexAttr(offset);
// Get the special case of 0-D out of the way.
if (sourceRank == 0) {
- Value constantOne = getValueOrCreateConstantIndexOp(
- rewriter, origLoc, rewriter.getIndexAttr(1));
- SmallVector<Value> resultValues(baseIdxInResult + reshapeRank * 2,
- constantOne);
- for (unsigned i = 0; i < baseIdxInResult; ++i)
- resultValues[i] =
- getValueOrCreateConstantIndexOp(rewriter, origLoc, results[i]);
- rewriter.replaceOp(op, resultValues);
+ SmallVector<OpFoldResult> ones(reshapeRank, rewriter.getIndexAttr(1));
+ auto memrefDesc = rewriter.create<memref::ReinterpretCastOp>(
+ origLoc, reshapeType, newExtractStridedMetadata.getBaseBuffer(),
+ offsetOfr, /*sizes=*/ones, /*strides=*/ones);
+ rewriter.replaceOp(reshape, memrefDesc.getResult());
return success();
}
+ SmallVector<OpFoldResult> finalSizes;
+ finalSizes.reserve(reshapeRank);
+ SmallVector<OpFoldResult> finalStrides;
+ finalStrides.reserve(reshapeRank);
+
// Compute the reshaped strides and sizes from the base strides and sizes.
SmallVector<OpFoldResult> origSizes =
getAsOpFoldResult(newExtractStridedMetadata.getSizes());
@@ -521,21 +515,20 @@ struct ExtractStridedMetadataOpReshapeFolder
reshape, rewriter, origSizes, origStrides, /*groupId=*/idx);
unsigned groupSize = reshapedSizes.size();
- const unsigned sizeStartIdx = baseIdxInResult;
- const unsigned strideStartIdx = sizeStartIdx + reshapeRank;
for (unsigned i = 0; i < groupSize; ++i) {
- results[sizeStartIdx + i] = reshapedSizes[i];
- results[strideStartIdx + i] = reshapedStrides[i];
+ finalSizes.push_back(reshapedSizes[i]);
+ finalStrides.push_back(reshapedStrides[i]);
}
- baseIdxInResult += groupSize;
}
assert(((isa<memref::ExpandShapeOp>(reshape) && idx == sourceRank) ||
(isa<memref::CollapseShapeOp>(reshape) && idx == reshapeRank)) &&
"We should have visited all the input dimensions");
- assert(baseIdxInResult == reshapeRank + 2 &&
+ assert(finalSizes.size() == reshapeRank &&
"We should have populated all the values");
- rewriter.replaceOp(
- op, getValueOrCreateConstantIndexOp(rewriter, origLoc, results));
+ auto memrefDesc = rewriter.create<memref::ReinterpretCastOp>(
+ origLoc, reshapeType, newExtractStridedMetadata.getBaseBuffer(),
+ offsetOfr, finalSizes, finalStrides);
+ rewriter.replaceOp(reshape, memrefDesc.getResult());
return success();
}
};
@@ -745,18 +738,17 @@ class ExtractStridedMetadataOpExtractStridedMetadataFolder
void memref::populateSimplifyExtractStridedMetadataOpPatterns(
RewritePatternSet &patterns) {
- patterns
- .add<SubviewFolder,
- ExtractStridedMetadataOpReshapeFolder<
- memref::ExpandShapeOp, getExpandedSizes, getExpandedStrides>,
- ExtractStridedMetadataOpReshapeFolder<
- memref::CollapseShapeOp, getCollapsedSize, getCollapsedStride>,
- ExtractStridedMetadataOpAllocFolder<memref::AllocOp>,
- ExtractStridedMetadataOpAllocFolder<memref::AllocaOp>,
- RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
- ExtractStridedMetadataOpReinterpretCastFolder,
- ExtractStridedMetadataOpExtractStridedMetadataFolder>(
- patterns.getContext());
+ patterns.add<SubviewFolder,
+ ReshapeFolder<memref::ExpandShapeOp, getExpandedSizes,
+ getExpandedStrides>,
+ ReshapeFolder<memref::CollapseShapeOp, getCollapsedSize,
+ getCollapsedStride>,
+ ExtractStridedMetadataOpAllocFolder<memref::AllocOp>,
+ ExtractStridedMetadataOpAllocFolder<memref::AllocaOp>,
+ RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
+ ExtractStridedMetadataOpReinterpretCastFolder,
+ ExtractStridedMetadataOpExtractStridedMetadataFolder>(
+ patterns.getContext());
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir b/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir
index 3f312df10e214..d7f2dfdb77b31 100644
--- a/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir
+++ b/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir
@@ -352,6 +352,88 @@ func.func @extract_strided_metadata_of_subview_all_dynamic(
// -----
+// Check that we properly simplify expand_shape into:
+// reinterpret_cast(extract_strided_metadata) + <some math>
+//
+// Here we have:
+// For the group applying to dim0:
+// size 0 = baseSizes#0 / (all static sizes in that group)
+// = baseSizes#0 / (7 * 8 * 9)
+// = baseSizes#0 / 504
+// size 1 = 7
+// size 2 = 8
+// size 3 = 9
+// stride 0 = baseStrides#0 * 7 * 8 * 9
+// = baseStrides#0 * 504
+// stride 1 = baseStrides#0 * 8 * 9
+// = baseStrides#0 * 72
+// stride 2 = baseStrides#0 * 9
+// stride 3 = baseStrides#0
+//
+// For the group applying to dim1:
+// size 4 = 10
+// size 5 = 2
+// size 6 = baseSizes#1 / (all static sizes in that group)
+// = baseSizes#1 / (10 * 2 * 3)
+// = baseSizes#1 / 60
+// size 7 = 3
+// stride 4 = baseStrides#1 * size 5 * size 6 * size 7
+// = baseStrides#1 * 2 * (baseSizes#1 / 60) * 3
+// = baseStrides#1 * (baseSizes#1 / 60) * 6
+// and since we know that baseSizes#1 is a multiple of 60:
+// = baseStrides#1 * (baseSizes#1 / 10)
+// stride 5 = baseStrides#1 * size 6 * size 7
+// = baseStrides#1 * (baseSizes#1 / 60) * 3
+// = baseStrides#1 * (baseSizes#1 / 20)
+// stride 6 = baseStrides#1 * size 7
+// = baseStrides#1 * 3
+// stride 7 = baseStrides#1
+//
+// Base and offset are unchanged.
+//
+// CHECK-DAG: #[[$DIM0_SIZE_MAP:.*]] = affine_map<()[s0] -> (s0 floordiv 504)>
+// CHECK-DAG: #[[$DIM6_SIZE_MAP:.*]] = affine_map<()[s0] -> (s0 floordiv 60)>
+//
+// CHECK-DAG: #[[$DIM0_STRIDE_MAP:.*]] = affine_map<()[s0] -> (s0 * 504)>
+// CHECK-DAG: #[[$DIM1_STRIDE_MAP:.*]] = affine_map<()[s0] -> (s0 * 72)>
+// CHECK-DAG: #[[$DIM2_STRIDE_MAP:.*]] = affine_map<()[s0] -> (s0 * 9)>
+// CHECK-DAG: #[[$DIM4_STRIDE_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 floordiv 10) * s1)>
+// CHECK-DAG: #[[$DIM5_STRIDE_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 floordiv 20) * s1)>
+// CHECK-DAG: #[[$DIM6_STRIDE_MAP:.*]] = affine_map<()[s0] -> (s0 * 3)>
+// CHECK-LABEL: func @simplify_expand_shape
+// CHECK-SAME: (%[[ARG:.*]]: memref<?x?xf32,
+//
+// CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]] : memref<?x?xf32, strided<[?, ?], offset: ?>> -> memref<f32>, index, index, index, index, index
+//
+// CHECK-DAG: %[[DYN_SIZE0:.*]] = affine.apply #[[$DIM0_SIZE_MAP]]()[%[[SIZES]]#0]
+// CHECK-DAG: %[[DYN_SIZE6:.*]] = affine.apply #[[$DIM6_SIZE_MAP]]()[%[[SIZES]]#1]
+// CHECK-DAG: %[[DYN_STRIDE0:.*]] = affine.apply #[[$DIM0_STRIDE_MAP]]()[%[[STRIDES]]#0]
+// CHECK-DAG: %[[DYN_STRIDE1:.*]] = affine.apply #[[$DIM1_STRIDE_MAP]]()[%[[STRIDES]]#0]
+// CHECK-DAG: %[[DYN_STRIDE2:.*]] = affine.apply #[[$DIM2_STRIDE_MAP]]()[%[[STRIDES]]#0]
+// CHECK-DAG: %[[DYN_STRIDE4:.*]] = affine.apply #[[$DIM4_STRIDE_MAP]]()[%[[SIZES]]#1, %[[STRIDES]]#1]
+// CHECK-DAG: %[[DYN_STRIDE5:.*]] = affine.apply #[[$DIM5_STRIDE_MAP]]()[%[[SIZES]]#1, %[[STRIDES]]#1]
+// CHECK-DAG: %[[DYN_STRIDE6:.*]] = affine.apply #[[$DIM6_STRIDE_MAP]]()[%[[STRIDES]]#1]
+//
+// CHECK-DAG: %[[REINTERPRET_CAST:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[OFFSET]]], sizes: [%[[DYN_SIZE0]], 7, 8, 9, 10, 2, %[[DYN_SIZE6]], 3], strides: [%[[DYN_STRIDE0]], %[[DYN_STRIDE1]], %[[DYN_STRIDE2]], %[[STRIDES]]#0, %[[DYN_STRIDE4]], %[[DYN_STRIDE5]], %[[DYN_STRIDE6]], %[[STRIDES]]#1]
+//
+// CHECK: return %[[REINTERPRET_CAST]]
+func.func @simplify_expand_shape(
+ %base: memref<?x?xf32, strided<[?,?], offset:?>>,
+ %offset0: index, %offset1: index, %offset2: index,
+ %size0: index, %size1: index, %size2: index,
+ %stride0: index, %stride1: index, %stride2: index)
+ -> memref<?x7x8x9x10x2x?x3xf32, strided<[?, ?, ?, ?, ?, ?, ?, ?], offset: ?>> {
+
+ %subview = memref.expand_shape %base[[0, 1, 2, 3],[4, 5, 6, 7]] :
+ memref<?x?xf32, strided<[?,?], offset: ?>> into
+ memref<?x7x8x9x10x2x?x3xf32, strided<[?, ?, ?, ?, ?, ?, ?, ?], offset: ?>>
+
+ return %subview :
+ memref<?x7x8x9x10x2x?x3xf32, strided<[?, ?, ?, ?, ?, ?, ?, ?], offset: ?>>
+}
+
+// -----
+
// Check that we properly simplify extract_strided_metadata of expand_shape
// into:
// baseBuffer, baseOffset, baseSizes, baseStrides =
@@ -815,6 +897,43 @@ func.func @extract_aligned_pointer_as_index(%arg0: memref<f32>) -> index {
// -----
+// Check that we simplify collapse_shape into
+// reinterpret_cast(extract_strided_metadata) + <some math>
+//
+// We transform: ?x?x4x?x6x7xi32 to [0][1,2,3][4,5]
+// Size 0 = origSize0
+// Size 1 = origSize1 * origSize2 * origSize3
+// = origSize1 * 4 * origSize3
+// Size 2 = origSize4 * origSize5
+// = 6 * 7
+// = 42
+// Stride 0 = origStride0
+// Stride 1 = origStride3 (orig stride of the inner most dimension)
+// = 42
+// Stride 2 = origStride5
+// = 1
+//
+// CHECK-DAG: #[[$SIZE0_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 * s1) * 4)>
+// CHECK-LABEL: func @simplify_collapse(
+// CHECK-SAME: %[[ARG:.*]]: memref<?x?x4x?x6x7xi32>)
+//
+// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:6, %[[STRIDES:.*]]:6 = memref.extract_strided_metadata %[[ARG]] : memref<?x?x4x?x6x7xi32>
+//
+// CHECK: %[[DYN_SIZE1:.*]] = affine.apply #[[$SIZE0_MAP]]()[%[[SIZES]]#1, %[[SIZES]]#3]
+//
+// CHECK: %[[COLLAPSE_VIEW:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [0], sizes: [%[[SIZES]]#0, %[[DYN_SIZE1]], 42], strides: [%[[STRIDES]]#0, 42, 1]
+func.func @simplify_collapse(%arg : memref<?x?x4x?x6x7xi32>)
+ -> memref<?x?x42xi32> {
+
+ %collapsed_view = memref.collapse_shape %arg [[0], [1, 2, 3], [4, 5]] :
+ memref<?x?x4x?x6x7xi32> into memref<?x?x42xi32>
+
+ return %collapsed_view : memref<?x?x42xi32>
+
+}
+
+// -----
+
// Check that we simplify extract_strided_metadata of collapse_shape.
//
// We transform: ?x?x4x?x6x7xi32 to [0][1,2,3][4,5]
More information about the Mlir-commits
mailing list