[Mlir-commits] [mlir] 450ac01 - [mlir][MemRef] Add ExtractStridedMetadataOpCollapseShapeFolder (#89954)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Apr 26 07:20:32 PDT 2024
Author: Diego Caballero
Date: 2024-04-26T16:20:24+02:00
New Revision: 450ac01bb9f4ee4f9e05eb6846f256bf2079a436
URL: https://github.com/llvm/llvm-project/commit/450ac01bb9f4ee4f9e05eb6846f256bf2079a436
DIFF: https://github.com/llvm/llvm-project/commit/450ac01bb9f4ee4f9e05eb6846f256bf2079a436.diff
LOG: [mlir][MemRef] Add ExtractStridedMetadataOpCollapseShapeFolder (#89954)
This PR adds a new pattern to the set of patterns used to resolve the offset, sizes and
strides of a memref. Similar to `ExtractStridedMetadataOpSubviewFolder`, the new
pattern resolves strided_metadata(collapse_shape) directly, without introducing a
reinterpret_cast op.
Added:
Modified:
mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
index 96eb7cfd2db690..585c5b73814219 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
@@ -550,6 +550,89 @@ getCollapsedStride(memref::CollapseShapeOp collapseShape, OpBuilder &builder,
return {makeComposedFoldedAffineMin(builder, collapseShape.getLoc(), minMap,
groupStrides)};
}
+
+/// From `reshape_like(memref, subSizes, subStrides)` compute
+///
+/// \verbatim
+/// baseBuffer, baseOffset, baseSizes, baseStrides =
+/// extract_strided_metadata(memref)
+/// strides#i = baseStrides#i * subStrides#i
+/// sizes = subSizes
+/// \endverbatim
+///
+/// and return {baseBuffer, baseOffset, sizes, strides}
+template <typename ReassociativeReshapeLikeOp>
+static FailureOr<StridedMetadata> resolveReshapeStridedMetadata(
+ RewriterBase &rewriter, ReassociativeReshapeLikeOp reshape,
+ function_ref<SmallVector<OpFoldResult>(
+ ReassociativeReshapeLikeOp, OpBuilder &,
+ ArrayRef<OpFoldResult> /*origSizes*/, unsigned /*groupId*/)>
+ getReshapedSizes,
+ function_ref<SmallVector<OpFoldResult>(
+ ReassociativeReshapeLikeOp, OpBuilder &,
+ ArrayRef<OpFoldResult> /*origSizes*/,
+ ArrayRef<OpFoldResult> /*origStrides*/, unsigned /*groupId*/)>
+ getReshapedStrides) {
+ // Build a plain extract_strided_metadata(memref) from
+ // extract_strided_metadata(reassociative_reshape_like(memref)).
+ Location origLoc = reshape.getLoc();
+ Value source = reshape.getSrc();
+ auto sourceType = cast<MemRefType>(source.getType());
+ unsigned sourceRank = sourceType.getRank();
+
+ auto newExtractStridedMetadata =
+ rewriter.create<memref::ExtractStridedMetadataOp>(origLoc, source);
+
+ // Collect statically known information.
+ auto [strides, offset] = getStridesAndOffset(sourceType);
+ MemRefType reshapeType = reshape.getResultType();
+ unsigned reshapeRank = reshapeType.getRank();
+
+ OpFoldResult offsetOfr =
+ ShapedType::isDynamic(offset)
+ ? getAsOpFoldResult(newExtractStridedMetadata.getOffset())
+ : rewriter.getIndexAttr(offset);
+
+ // Get the special case of 0-D out of the way.
+ if (sourceRank == 0) {
+ SmallVector<OpFoldResult> ones(reshapeRank, rewriter.getIndexAttr(1));
+ return StridedMetadata{newExtractStridedMetadata.getBaseBuffer(), offsetOfr,
+ /*sizes=*/ones, /*strides=*/ones};
+ }
+
+ SmallVector<OpFoldResult> finalSizes;
+ finalSizes.reserve(reshapeRank);
+ SmallVector<OpFoldResult> finalStrides;
+ finalStrides.reserve(reshapeRank);
+
+ // Compute the reshaped strides and sizes from the base strides and sizes.
+ SmallVector<OpFoldResult> origSizes =
+ getAsOpFoldResult(newExtractStridedMetadata.getSizes());
+ SmallVector<OpFoldResult> origStrides =
+ getAsOpFoldResult(newExtractStridedMetadata.getStrides());
+ unsigned idx = 0, endIdx = reshape.getReassociationIndices().size();
+ for (; idx != endIdx; ++idx) {
+ SmallVector<OpFoldResult> reshapedSizes =
+ getReshapedSizes(reshape, rewriter, origSizes, /*groupId=*/idx);
+ SmallVector<OpFoldResult> reshapedStrides = getReshapedStrides(
+ reshape, rewriter, origSizes, origStrides, /*groupId=*/idx);
+
+ unsigned groupSize = reshapedSizes.size();
+ for (unsigned i = 0; i < groupSize; ++i) {
+ finalSizes.push_back(reshapedSizes[i]);
+ finalStrides.push_back(reshapedStrides[i]);
+ }
+ }
+ assert(((isa<memref::ExpandShapeOp>(reshape) && idx == sourceRank) ||
+ (isa<memref::CollapseShapeOp>(reshape) && idx == reshapeRank)) &&
+ "We should have visited all the input dimensions");
+ assert(finalSizes.size() == reshapeRank &&
+ "We should have populated all the values");
+
+ return StridedMetadata{newExtractStridedMetadata.getBaseBuffer(), offsetOfr,
+ finalSizes, finalStrides};
+}
+
/// Replace `baseBuffer, offset, sizes, strides =
/// extract_strided_metadata(reshapeLike(memref))`
/// With
@@ -580,68 +663,65 @@ struct ReshapeFolder : public OpRewritePattern<ReassociativeReshapeLikeOp> {
LogicalResult matchAndRewrite(ReassociativeReshapeLikeOp reshape,
PatternRewriter &rewriter) const override {
- // Build a plain extract_strided_metadata(memref) from
- // extract_strided_metadata(reassociative_reshape_like(memref)).
- Location origLoc = reshape.getLoc();
- Value source = reshape.getSrc();
- auto sourceType = cast<MemRefType>(source.getType());
- unsigned sourceRank = sourceType.getRank();
-
- auto newExtractStridedMetadata =
- rewriter.create<memref::ExtractStridedMetadataOp>(origLoc, source);
-
- // Collect statically known information.
- auto [strides, offset] = getStridesAndOffset(sourceType);
- MemRefType reshapeType = reshape.getResultType();
- unsigned reshapeRank = reshapeType.getRank();
-
- OpFoldResult offsetOfr =
- ShapedType::isDynamic(offset)
- ? getAsOpFoldResult(newExtractStridedMetadata.getOffset())
- : rewriter.getIndexAttr(offset);
-
- // Get the special case of 0-D out of the way.
- if (sourceRank == 0) {
- SmallVector<OpFoldResult> ones(reshapeRank, rewriter.getIndexAttr(1));
- auto memrefDesc = rewriter.create<memref::ReinterpretCastOp>(
- origLoc, reshapeType, newExtractStridedMetadata.getBaseBuffer(),
- offsetOfr, /*sizes=*/ones, /*strides=*/ones);
- rewriter.replaceOp(reshape, memrefDesc.getResult());
- return success();
+ FailureOr<StridedMetadata> stridedMetadata =
+ resolveReshapeStridedMetadata<ReassociativeReshapeLikeOp>(
+ rewriter, reshape, getReshapedSizes, getReshapedStrides);
+ if (failed(stridedMetadata)) {
+ return rewriter.notifyMatchFailure(reshape,
+ "failed to resolve reshape metadata");
}
- SmallVector<OpFoldResult> finalSizes;
- finalSizes.reserve(reshapeRank);
- SmallVector<OpFoldResult> finalStrides;
- finalStrides.reserve(reshapeRank);
-
- // Compute the reshaped strides and sizes from the base strides and sizes.
- SmallVector<OpFoldResult> origSizes =
- getAsOpFoldResult(newExtractStridedMetadata.getSizes());
- SmallVector<OpFoldResult> origStrides =
- getAsOpFoldResult(newExtractStridedMetadata.getStrides());
- unsigned idx = 0, endIdx = reshape.getReassociationIndices().size();
- for (; idx != endIdx; ++idx) {
- SmallVector<OpFoldResult> reshapedSizes =
- getReshapedSizes(reshape, rewriter, origSizes, /*groupId=*/idx);
- SmallVector<OpFoldResult> reshapedStrides = getReshapedStrides(
- reshape, rewriter, origSizes, origStrides, /*groupId=*/idx);
-
- unsigned groupSize = reshapedSizes.size();
- for (unsigned i = 0; i < groupSize; ++i) {
- finalSizes.push_back(reshapedSizes[i]);
- finalStrides.push_back(reshapedStrides[i]);
- }
+ rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
+ reshape, reshape.getType(), stridedMetadata->basePtr,
+ stridedMetadata->offset, stridedMetadata->sizes,
+ stridedMetadata->strides);
+ return success();
+ }
+};
+
+/// Pattern to replace `extract_strided_metadata(collapse_shape)`
+/// With
+///
+/// \verbatim
+/// baseBuffer, baseOffset, baseSizes, baseStrides =
+/// extract_strided_metadata(memref)
+/// strides#i = baseStrides#i * subSizes#i
+/// offset = baseOffset
+/// sizes = subSizes
+/// \endverbatim
+///
+/// with `baseBuffer`, `offset`, `sizes` and `strides` being
+/// the replacements for the original `extract_strided_metadata`.
+struct ExtractStridedMetadataOpCollapseShapeFolder
+ : OpRewritePattern<memref::ExtractStridedMetadataOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
+ PatternRewriter &rewriter) const override {
+ auto collapseShapeOp =
+ op.getSource().getDefiningOp<memref::CollapseShapeOp>();
+ if (!collapseShapeOp)
+ return failure();
+
+ FailureOr<StridedMetadata> stridedMetadata =
+ resolveReshapeStridedMetadata<memref::CollapseShapeOp>(
+ rewriter, collapseShapeOp, getCollapsedSize, getCollapsedStride);
+ if (failed(stridedMetadata)) {
+ return rewriter.notifyMatchFailure(
+ op,
+ "failed to resolve metadata in terms of source collapse_shape op");
}
- assert(((isa<memref::ExpandShapeOp>(reshape) && idx == sourceRank) ||
- (isa<memref::CollapseShapeOp>(reshape) && idx == reshapeRank)) &&
- "We should have visited all the input dimensions");
- assert(finalSizes.size() == reshapeRank &&
- "We should have populated all the values");
- auto memrefDesc = rewriter.create<memref::ReinterpretCastOp>(
- origLoc, reshapeType, newExtractStridedMetadata.getBaseBuffer(),
- offsetOfr, finalSizes, finalStrides);
- rewriter.replaceOp(reshape, memrefDesc.getResult());
+
+ Location loc = collapseShapeOp.getLoc();
+ SmallVector<Value> results;
+ results.push_back(stridedMetadata->basePtr);
+ results.push_back(getValueOrCreateConstantIndexOp(rewriter, loc,
+ stridedMetadata->offset));
+ results.append(
+ getValueOrCreateConstantIndexOp(rewriter, loc, stridedMetadata->sizes));
+ results.append(getValueOrCreateConstantIndexOp(rewriter, loc,
+ stridedMetadata->strides));
+ rewriter.replaceOp(op, results);
return success();
}
};
@@ -1018,9 +1098,11 @@ void memref::populateExpandStridedMetadataPatterns(
getCollapsedStride>,
ExtractStridedMetadataOpAllocFolder<memref::AllocOp>,
ExtractStridedMetadataOpAllocFolder<memref::AllocaOp>,
+ ExtractStridedMetadataOpCollapseShapeFolder,
ExtractStridedMetadataOpGetGlobalFolder,
RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
ExtractStridedMetadataOpReinterpretCastFolder,
+ ExtractStridedMetadataOpSubviewFolder,
ExtractStridedMetadataOpCastFolder,
ExtractStridedMetadataOpExtractStridedMetadataFolder>(
patterns.getContext());
@@ -1030,6 +1112,7 @@ void memref::populateResolveExtractStridedMetadataPatterns(
RewritePatternSet &patterns) {
patterns.add<ExtractStridedMetadataOpAllocFolder<memref::AllocOp>,
ExtractStridedMetadataOpAllocFolder<memref::AllocaOp>,
+ ExtractStridedMetadataOpCollapseShapeFolder,
ExtractStridedMetadataOpGetGlobalFolder,
ExtractStridedMetadataOpSubviewFolder,
RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
diff --git a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
index 28b70043005940..0705b30ca45d86 100644
--- a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
+++ b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
@@ -1513,4 +1513,26 @@ func.func @zero_sized_memred(%arg0: f32) -> (memref<f16, 3>, index,index,index)
%sizes, %strides :
memref<f16,3>, index,
index, index
-}
\ No newline at end of file
+}
+
+// -----
+
+func.func @extract_strided_metadata_of_collapse_shape(%base: memref<5x4xf32>)
+ -> (memref<f32>, index, index, index) {
+
+ %collapse = memref.collapse_shape %base[[0, 1]] :
+ memref<5x4xf32> into memref<20xf32>
+
+ %base_buffer, %offset, %size, %stride = memref.extract_strided_metadata %collapse :
+ memref<20xf32> -> memref<f32>, index, index, index
+
+ return %base_buffer, %offset, %size, %stride :
+ memref<f32>, index, index, index
+}
+
+// CHECK-LABEL: func @extract_strided_metadata_of_collapse_shape
+// CHECK-DAG: %[[OFFSET:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[SIZE:.*]] = arith.constant 20 : index
+// CHECK-DAG: %[[STEP:.*]] = arith.constant 1 : index
+// CHECK: %[[BASE:.*]], %{{.*}}, %{{.*}}, %{{.*}} = memref.extract_strided_metadata
+// CHECK: return %[[BASE]], %[[OFFSET]], %[[SIZE]], %[[STEP]] : memref<f32>, index, index, index
More information about the Mlir-commits
mailing list