[Mlir-commits] [mlir] d831568 - [mlir][MemRef] Simplify extract_strided_metadata(collapse_shape)
Quentin Colombet
llvmlistbot at llvm.org
Fri Sep 30 10:08:41 PDT 2022
Author: Quentin Colombet
Date: 2022-09-30T16:54:56Z
New Revision: d8315681714222e83e32beba374a5ff976d90059
URL: https://github.com/llvm/llvm-project/commit/d8315681714222e83e32beba374a5ff976d90059
DIFF: https://github.com/llvm/llvm-project/commit/d8315681714222e83e32beba374a5ff976d90059.diff
LOG: [mlir][MemRef] Simplify extract_strided_metadata(collapse_shape)
The new pattern gets rid of the collapse_shape operation while
materializing its effects on the sizes and strides of the base object.
In other words, this simplification replaces:
```
baseBuffer, offset, sizes, strides =
    extract_strided_metadata(collapse_shape(memref))
```
With
```
baseBuffer, offset, baseSizes, baseStrides =
    extract_strided_metadata(memref)
for reassDim in {0 .. collapseRank - 1}
  sizes#reassDim = product(baseSizes#i for i in group[reassDim])
  strides#reassDim = baseStrides[group[reassDim].back()]
```
Note: baseBuffer and offset are unaffected by the collapse_shape
operation.
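For illustration, here is a sketch of the rewrite on the first test case
added below (value names are made up; the affine map and the folded
constants are paraphrased from the CHECK lines rather than exact pass
output):
```
// Before: the metadata is taken on the collapsed view.
%collapsed = memref.collapse_shape %arg [[0], [1, 2, 3], [4, 5]]
    : memref<?x?x4x?x6x7xi32> into memref<?x?x42xi32>
%base, %offset, %sizes:3, %strides:3 =
    memref.extract_strided_metadata %collapsed : memref<?x?x42xi32>
    -> memref<i32>, index, index, index, index, index, index, index

// After: the metadata is taken directly on the original memref and the
// effects of the collapse are materialized on the sizes and strides.
%base, %offset, %origSizes:6, %origStrides:6 =
    memref.extract_strided_metadata %arg : memref<?x?x4x?x6x7xi32>
    -> memref<i32>, index,
       index, index, index, index, index, index,
       index, index, index, index, index, index
// base and offset are forwarded as is (the offset folds to 0 here).
// size#0 = origSizes#0
// size#1 = origSizes#1 * 4 * origSizes#3
%size1 = affine.apply affine_map<()[s0, s1] -> ((s0 * s1) * 4)>()[%origSizes#1, %origSizes#3]
// size#2 = 6 * 7 = 42 (folds to a constant)
// stride#i = stride of the innermost dimension of group i:
//   stride#0 = origStrides#0, stride#1 = 42, stride#2 = 1
```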
Differential Revision: https://reviews.llvm.org/D134826
Added:
Modified:
mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
index 462950fd5268c..5a14578fa112a 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
@@ -173,6 +173,11 @@ struct ExtractStridedMetadataOpSubviewFolder
/// \p origSizes hold the sizes of the source shape as values.
/// This is used to compute the new sizes in cases of dynamic shapes.
///
+/// sizes#i =
+/// baseSizes#groupId / product(expandShapeSizes#j,
+/// for j in group excluding reassIdx#i)
+/// Where reassIdx#i is the reassociation index at index i in \p groupId.
+///
/// \post result.size() == expandShape.getReassociationIndices()[groupId].size()
///
/// TODO: Move this utility function directly within ExpandShapeOp. For now,
@@ -225,6 +230,18 @@ getExpandedSizes(memref::ExpandShapeOp expandShape, OpBuilder &builder,
/// This is used to compute the strides in cases of dynamic shapes and/or
/// dynamic stride for this reassociation group.
///
+/// strides#i =
+/// origStrides#reassDim * product(expandShapeSizes#j, for j in
+/// reassIdx#i+1..reassIdx#i+group.size-1)
+///
+/// Where reassIdx#i is the reassociation index at index i in \p groupId
+/// and expandShapeSizes#j is either:
+/// - The constant size at dimension j, derived directly from the result type of
+/// the expand_shape op, or
+/// - An affine expression: baseSizes#reassDim / product of all constant sizes
+/// in expandShapeSizes. (Remember expandShapeSizes has at most one dynamic
+/// element.)
+///
/// \post result.size() == expandShape.getReassociationIndices()[groupId].size()
///
/// TODO: Move this utility function directly within ExpandShapeOp. For now,
@@ -315,49 +332,162 @@ SmallVector<OpFoldResult> getExpandedStrides(memref::ExpandShapeOp expandShape,
return expandedStrides;
}
+/// Produce an OpFoldResult object with \p builder at \p loc representing
+/// `prod(valueOrConstant#i, for i in {indices})`,
+/// where valueOrConstant#i is maybeConstants[i] when \p isDynamic is false,
+/// and values[i] otherwise.
+///
+/// \pre for all index in indices: index < values.size()
+/// \pre for all index in indices: index < maybeConstants.size()
+static OpFoldResult
+getProductOfValues(ArrayRef<int64_t> indices, OpBuilder &builder, Location loc,
+ ArrayRef<int64_t> maybeConstants,
+ ArrayRef<OpFoldResult> values,
+ llvm::function_ref<bool(int64_t)> isDynamic) {
+ AffineExpr productOfValues = builder.getAffineConstantExpr(1);
+ SmallVector<OpFoldResult> inputValues;
+ unsigned numberOfSymbols = 0;
+ unsigned groupSize = indices.size();
+ for (unsigned i = 0; i < groupSize; ++i) {
+ productOfValues =
+ productOfValues * builder.getAffineSymbolExpr(numberOfSymbols++);
+ unsigned srcIdx = indices[i];
+ int64_t maybeConstant = maybeConstants[srcIdx];
+
+ inputValues.push_back(isDynamic(maybeConstant)
+ ? values[srcIdx]
+ : builder.getIndexAttr(maybeConstant));
+ }
+
+ return makeComposedFoldedAffineApply(builder, loc, productOfValues,
+ inputValues);
+}
+
+/// Compute the collapsed size of the given \p collapseShape for the
+/// \p groupId-th reassociation group.
+/// \p origSizes hold the sizes of the source shape as values.
+/// This is used to compute the new sizes in cases of dynamic shapes.
+///
+/// Conceptually this helper function computes:
+/// `prod(origSizes#i, for i in {reassociationGroup[groupId]})`.
+///
+/// \post result.size() == 1, in other words, each group collapses to one
+/// dimension.
+///
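+/// For example, with a source of shape `?x?x4x?x6x7` collapsed through
+/// reassociation `[[0], [1, 2, 3], [4, 5]]` (see the test below), group 1
+/// produces `origSizes#1 * 4 * origSizes#3` and group 2 folds to the
+/// constant 42.
+///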
+/// TODO: Move this utility function directly within CollapseShapeOp. For now,
+/// this is not possible because this function uses the Affine dialect and the
+/// MemRef dialect cannot depend on the Affine dialect.
+static SmallVector<OpFoldResult>
+getCollapsedSize(memref::CollapseShapeOp collapseShape, OpBuilder &builder,
+ ArrayRef<OpFoldResult> origSizes, unsigned groupId) {
+ SmallVector<OpFoldResult> collapsedSize;
+
+ MemRefType collapseShapeType = collapseShape.getResultType();
+
+ uint64_t size = collapseShapeType.getDimSize(groupId);
+ if (!ShapedType::isDynamic(size)) {
+ collapsedSize.push_back(builder.getIndexAttr(size));
+ return collapsedSize;
+ }
+
+ // We are dealing with a dynamic size.
+ // Build the affine expr of the product of the original sizes involved in that
+ // group.
+ Value source = collapseShape.getSrc();
+ auto sourceType = source.getType().cast<MemRefType>();
+
+ SmallVector<int64_t, 2> reassocGroup =
+ collapseShape.getReassociationIndices()[groupId];
+
+ collapsedSize.push_back(getProductOfValues(
+ reassocGroup, builder, collapseShape.getLoc(), sourceType.getShape(),
+ origSizes, ShapedType::isDynamic));
+
+ return collapsedSize;
+}
+
+/// Compute the collapsed stride of the given \p collapseShape for the
+/// \p groupId-th reassociation group.
+/// \p origStrides and \p origSizes hold respectively the strides and sizes
+/// of the source shape as values.
+/// This is used to compute the strides in cases of dynamic shapes and/or
+/// dynamic stride for this reassociation group.
+///
+/// Conceptually this helper function returns the stride of the innermost
+/// dimension of that group in the original shape.
+///
+/// \post result.size() == 1, in other words, each group collapses to one
+/// dimension.
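+///
+/// For example, with a source of shape `?x?x4x?x6x7` collapsed through
+/// `[[0], [1, 2, 3], [4, 5]]` (see the test below), group 1 gets the stride
+/// of dimension 3 (42 with the identity layout) and group 2 gets the stride
+/// of dimension 5 (1).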
+static SmallVector<OpFoldResult>
+getCollapsedStride(memref::CollapseShapeOp collapseShape, OpBuilder &builder,
+ ArrayRef<OpFoldResult> origSizes,
+ ArrayRef<OpFoldResult> origStrides, unsigned groupId) {
+ SmallVector<int64_t, 2> reassocGroup =
+ collapseShape.getReassociationIndices()[groupId];
+ assert(!reassocGroup.empty() &&
+ "Reassociation group should have at least one dimension");
+
+ Value source = collapseShape.getSrc();
+ auto sourceType = source.getType().cast<MemRefType>();
+
+ SmallVector<int64_t> strides;
+ int64_t offset;
+ bool hasKnownStridesAndOffset =
+ succeeded(getStridesAndOffset(sourceType, strides, offset));
+ (void)hasKnownStridesAndOffset;
+ assert(hasKnownStridesAndOffset &&
+ "getStridesAndOffset must work on valid collapse_shape");
+
+ SmallVector<OpFoldResult> collapsedStride;
+ int64_t innerMostDimForGroup = reassocGroup.back();
+ int64_t innerMostStrideForGroup = strides[innerMostDimForGroup];
+ collapsedStride.push_back(
+ ShapedType::isDynamicStrideOrOffset(innerMostStrideForGroup)
+ ? origStrides[innerMostDimForGroup]
+ : builder.getIndexAttr(innerMostStrideForGroup));
+
+ return collapsedStride;
+}
/// Replace `baseBuffer, offset, sizes, strides =
-/// extract_strided_metadata(expand_shape(memref))`
+/// extract_strided_metadata(reshapeLike(memref))`
/// With
///
/// \verbatim
/// baseBuffer, offset, baseSizes, baseStrides =
/// extract_strided_metadata(memref)
-/// sizes#reassIdx =
-/// baseSizes#reassDim / product(expandShapeSizes#j,
-/// for j in group excluding reassIdx)
-/// strides#reassIdx =
-/// baseStrides#reassDim * product(expandShapeSizes#j, for j in
-/// reassIdx+1..reassIdx+group.size-1)
+/// sizes = getReshapedSizes(reshapeLike)
+/// strides = getReshapedStrides(reshapeLike)
/// \endverbatim
///
-/// Where reassIdx is a reassociation index for the group at reassDim
-/// and expandShapeSizes#j is either:
-/// - The constant size at dimension j, derived directly from the result type of
-/// the expand_shape op, or
-/// - An affine expression: baseSizes#reassDim / product of all constant sizes
-/// in expandShapeSizes. (Remember expandShapeSizes has at most one dynamic
-/// element.)
///
/// Notice that `baseBuffer` and `offset` are unchanged.
///
/// In other words, get rid of the expand_shape in that expression and
/// materialize its effects on the sizes and the strides using affine apply.
-struct ExtractStridedMetadataOpExpandShapeFolder
+template <typename ReassociativeReshapeLikeOp,
+ SmallVector<OpFoldResult> (*getReshapedSizes)(
+ ReassociativeReshapeLikeOp, OpBuilder &,
+ ArrayRef<OpFoldResult> /*origSizes*/, unsigned /*groupId*/),
+ SmallVector<OpFoldResult> (*getReshapedStrides)(
+ ReassociativeReshapeLikeOp, OpBuilder &,
+ ArrayRef<OpFoldResult> /*origSizes*/,
+ ArrayRef<OpFoldResult> /*origStrides*/, unsigned /*groupId*/)>
+struct ExtractStridedMetadataOpReshapeFolder
: public OpRewritePattern<memref::ExtractStridedMetadataOp> {
public:
using OpRewritePattern<memref::ExtractStridedMetadataOp>::OpRewritePattern;
LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
PatternRewriter &rewriter) const override {
- auto expandShape = op.getSource().getDefiningOp<memref::ExpandShapeOp>();
- if (!expandShape)
+ auto reshape = op.getSource().getDefiningOp<ReassociativeReshapeLikeOp>();
+ if (!reshape)
return failure();
// Build a plain extract_strided_metadata(memref) from
- // extract_strided_metadata(expand_shape(memref)).
+ // extract_strided_metadata(reassociative_reshape_like(memref)).
Location origLoc = op.getLoc();
IndexType indexType = rewriter.getIndexType();
- Value source = expandShape.getSrc();
+ Value source = reshape.getSrc();
auto sourceType = source.getType().cast<MemRefType>();
unsigned sourceRank = sourceType.getRank();
SmallVector<Type> sizeStrideTypes(sourceRank, indexType);
@@ -374,13 +504,13 @@ struct ExtractStridedMetadataOpExpandShapeFolder
succeeded(getStridesAndOffset(sourceType, strides, offset));
(void)hasKnownStridesAndOffset;
assert(hasKnownStridesAndOffset &&
- "getStridesAndOffset must work on valid expand_shape");
- MemRefType expandShapeType = expandShape.getResultType();
- unsigned expandShapeRank = expandShapeType.getRank();
+ "getStridesAndOffset must work on valid reassociative_reshape_like");
+ MemRefType reshapeType = reshape.getResultType();
+ unsigned reshapeRank = reshapeType.getRank();
// The result value will start with the base_buffer and offset.
unsigned baseIdxInResult = 2;
- SmallVector<OpFoldResult> results(baseIdxInResult + expandShapeRank * 2);
+ SmallVector<OpFoldResult> results(baseIdxInResult + reshapeRank * 2);
results[0] = newExtractStridedMetadata.getBaseBuffer();
results[1] = ShapedType::isDynamicStrideOrOffset(offset)
? getAsOpFoldResult(newExtractStridedMetadata.getOffset())
@@ -390,7 +520,7 @@ struct ExtractStridedMetadataOpExpandShapeFolder
if (sourceRank == 0) {
Value constantOne = getValueOrCreateConstantIndexOp(
rewriter, origLoc, rewriter.getIndexAttr(1));
- SmallVector<Value> resultValues(baseIdxInResult + expandShapeRank * 2,
+ SmallVector<Value> resultValues(baseIdxInResult + reshapeRank * 2,
constantOne);
for (unsigned i = 0; i < baseIdxInResult; ++i)
resultValues[i] =
@@ -399,30 +529,31 @@ struct ExtractStridedMetadataOpExpandShapeFolder
return success();
}
- // Compute the expanded strides and sizes from the base strides and sizes.
+ // Compute the reshaped strides and sizes from the base strides and sizes.
SmallVector<OpFoldResult> origSizes =
getAsOpFoldResult(newExtractStridedMetadata.getSizes());
SmallVector<OpFoldResult> origStrides =
getAsOpFoldResult(newExtractStridedMetadata.getStrides());
- unsigned idx = 0, endIdx = expandShape.getReassociationIndices().size();
+ unsigned idx = 0, endIdx = reshape.getReassociationIndices().size();
for (; idx != endIdx; ++idx) {
- SmallVector<OpFoldResult> expandedSizes =
- getExpandedSizes(expandShape, rewriter, origSizes, /*groupId=*/idx);
- SmallVector<OpFoldResult> expandedStrides = getExpandedStrides(
- expandShape, rewriter, origSizes, origStrides, /*groupId=*/idx);
+ SmallVector<OpFoldResult> reshapedSizes =
+ getReshapedSizes(reshape, rewriter, origSizes, /*groupId=*/idx);
+ SmallVector<OpFoldResult> reshapedStrides = getReshapedStrides(
+ reshape, rewriter, origSizes, origStrides, /*groupId=*/idx);
- unsigned groupSize = expandShape.getReassociationIndices()[idx].size();
+ unsigned groupSize = reshapedSizes.size();
const unsigned sizeStartIdx = baseIdxInResult;
- const unsigned strideStartIdx = sizeStartIdx + expandShapeRank;
+ const unsigned strideStartIdx = sizeStartIdx + reshapeRank;
for (unsigned i = 0; i < groupSize; ++i) {
- results[sizeStartIdx + i] = expandedSizes[i];
- results[strideStartIdx + i] = expandedStrides[i];
+ results[sizeStartIdx + i] = reshapedSizes[i];
+ results[strideStartIdx + i] = reshapedStrides[i];
}
baseIdxInResult += groupSize;
}
- assert(idx == sourceRank &&
+ assert(((isa<memref::ExpandShapeOp>(reshape) && idx == sourceRank) ||
+ (isa<memref::CollapseShapeOp>(reshape) && idx == reshapeRank)) &&
"We should have visited all the input dimensions");
- assert(baseIdxInResult == expandShapeRank + 2 &&
+ assert(baseIdxInResult == reshapeRank + 2 &&
"We should have populated all the values");
rewriter.replaceOp(
op, getValueOrCreateConstantIndexOp(rewriter, origLoc, results));
@@ -599,12 +730,17 @@ class RewriteExtractAlignedPointerAsIndexOfViewLikeOp
void memref::populateSimplifyExtractStridedMetadataOpPatterns(
RewritePatternSet &patterns) {
- patterns.add<ExtractStridedMetadataOpSubviewFolder,
- ExtractStridedMetadataOpExpandShapeFolder, ForwardStaticMetadata,
- ExtractStridedMetadataOpAllocFolder<memref::AllocOp>,
- ExtractStridedMetadataOpAllocFolder<memref::AllocaOp>,
- RewriteExtractAlignedPointerAsIndexOfViewLikeOp>(
- patterns.getContext());
+ patterns
+ .add<ExtractStridedMetadataOpSubviewFolder,
+ ExtractStridedMetadataOpReshapeFolder<
+ memref::ExpandShapeOp, getExpandedSizes, getExpandedStrides>,
+ ExtractStridedMetadataOpReshapeFolder<
+ memref::CollapseShapeOp, getCollapsedSize, getCollapsedStride>,
+ ForwardStaticMetadata,
+ ExtractStridedMetadataOpAllocFolder<memref::AllocOp>,
+ ExtractStridedMetadataOpAllocFolder<memref::AllocaOp>,
+ RewriteExtractAlignedPointerAsIndexOfViewLikeOp>(
+ patterns.getContext());
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir b/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir
index 2862ef96a53c8..616b835842910 100644
--- a/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir
+++ b/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir
@@ -762,3 +762,82 @@ func.func @extract_aligned_pointer_as_index(%arg0: memref<f32>) -> index {
%r = memref.extract_aligned_pointer_as_index %arg0: memref<f32> -> index
return %r : index
}
+
+// -----
+
+// Check that we simplify extract_strided_metadata of collapse_shape.
+//
+// We transform: ?x?x4x?x6x7xi32 to [0][1,2,3][4,5]
+// Size 0 = origSize0
+// Size 1 = origSize1 * origSize2 * origSize3
+// = origSize1 * 4 * origSize3
+// Size 2 = origSize4 * origSize5
+// = 6 * 7
+// = 42
+// Stride 0 = origStride0
+// Stride 1 = origStride3 (orig stride of the innermost dimension)
+// = 42
+// Stride 2 = origStride5
+// = 1
+//
+// CHECK-DAG: #[[$SIZE0_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 * s1) * 4)>
+// CHECK-LABEL: func @extract_strided_metadata_of_collapse(
+// CHECK-SAME: %[[ARG:.*]]: memref<?x?x4x?x6x7xi32>)
+//
+// CHECK-DAG: %[[C42:.*]] = arith.constant 42 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+//
+// CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:6, %[[STRIDES:.*]]:6 = memref.extract_strided_metadata %[[ARG]] : memref<?x?x4x?x6x7xi32>
+//
+// CHECK-DAG: %[[DYN_SIZE1:.*]] = affine.apply #[[$SIZE0_MAP]]()[%[[SIZES]]#1, %[[SIZES]]#3]
+//
+// CHECK: return %[[BASE]], %[[C0]], %[[SIZES]]#0, %[[DYN_SIZE1]], %[[C42]], %[[STRIDES]]#0, %[[C42]], %[[C1]]
+func.func @extract_strided_metadata_of_collapse(%arg : memref<?x?x4x?x6x7xi32>)
+ -> (memref<i32>, index,
+ index, index, index,
+ index, index, index) {
+
+ %collapsed_view = memref.collapse_shape %arg [[0], [1, 2, 3], [4, 5]] :
+ memref<?x?x4x?x6x7xi32> into memref<?x?x42xi32>
+
+ %base, %offset, %sizes:3, %strides:3 =
+ memref.extract_strided_metadata %collapsed_view : memref<?x?x42xi32>
+ -> memref<i32>, index,
+ index, index, index,
+ index, index, index
+
+ return %base, %offset,
+ %sizes#0, %sizes#1, %sizes#2,
+ %strides#0, %strides#1, %strides#2 :
+ memref<i32>, index,
+ index, index, index,
+ index, index, index
+
+}
+
+// -----
+
+// Check that we simplify extract_strided_metadata of a collapse_shape to
+// a rank-0 memref.
+// CHECK-LABEL: func @extract_strided_metadata_of_collapse_to_rank0(
+// CHECK-SAME: %[[ARG:.*]]: memref<1x1x1x1x1x1xi32>)
+//
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+//
+// CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:6, %[[STRIDES:.*]]:6 = memref.extract_strided_metadata %[[ARG]] : memref<1x1x1x1x1x1xi32>
+//
+// CHECK: return %[[BASE]], %[[C0]]
+func.func @extract_strided_metadata_of_collapse_to_rank0(%arg : memref<1x1x1x1x1x1xi32>)
+ -> (memref<i32>, index) {
+
+ %collapsed_view = memref.collapse_shape %arg [] :
+ memref<1x1x1x1x1x1xi32> into memref<i32>
+
+ %base, %offset =
+ memref.extract_strided_metadata %collapsed_view : memref<i32>
+ -> memref<i32>, index
+
+ return %base, %offset :
+ memref<i32>, index
+}