[Mlir-commits] [mlir] d0aeb74 - [mlir][MemRef] Simplify extract_strided_metadata(expand_shape)
Quentin Colombet
llvmlistbot at llvm.org
Thu Sep 22 12:14:41 PDT 2022
Author: Quentin Colombet
Date: 2022-09-22T19:07:09Z
New Revision: d0aeb74e8869e5db23c079b98c5e1f325aeeeefe
URL: https://github.com/llvm/llvm-project/commit/d0aeb74e8869e5db23c079b98c5e1f325aeeeefe
DIFF: https://github.com/llvm/llvm-project/commit/d0aeb74e8869e5db23c079b98c5e1f325aeeeefe.diff
LOG: [mlir][MemRef] Simplify extract_strided_metadata(expand_shape)
Add a pattern to the SimplifyExtractStridedMetadata pass that simplifies
extract_strided_metadata(other_op(memref)).
The new pattern gets rid of the expand_shape operation while
materializing its effects on the sizes and the strides of the base
object.
In other words, this simplification replaces:
```
baseBuffer, offset, sizes, strides =
    extract_strided_metadata(expand_shape(memref))
```
With
```
baseBuffer, offset, baseSizes, baseStrides =
    extract_strided_metadata(memref)
sizes#reassIdx =
    baseSizes#reassDim / product(expandShapeSizes#j,
                                 for j in group excluding reassIdx)
strides#reassIdx =
    baseStrides#reassDim * product(expandShapeSizes#j,
                                   for j in reassIdx+1..reassIdx+group.size-1)
```
Where `reassIdx` is a reassociation index for the group at
`reassDim` and `expandShapeSizes#j` is either:
- The constant size at dimension j, derived directly from the
result type of the expand_shape op, or
- An affine expression: baseSizes#reassDim / product of all
constant sizes in expandShapeSizes.
Note: baseBuffer and offset are unaffected by the expand_shape
operation.
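To make the formulas concrete, here is the arithmetic for the fully
static test added below, where memref<30x4xi16> (strides [4, 1]) is
expanded into memref<3x5x2x2x2xi16> with reassociation groups
[0, 1, 2] and [3, 4]:
```
Group [0, 1, 2] (reassDim = 0, baseSizes#0 = 30, baseStrides#0 = 4):
  sizes   = 3, 5, 2
  strides = 4 * (5 * 2), 4 * 2, 4  =  40, 8, 4
Group [3, 4] (reassDim = 1, baseSizes#1 = 4, baseStrides#1 = 1):
  sizes   = 2, 2
  strides = 1 * 2, 1  =  2, 1
```
When a group contains a dynamic dimension, as in the dynamic test below
where baseSizes#0 is expanded into ?x7x8x9, the dynamic size is
recovered with an affine.apply computing
baseSizes#0 floordiv (7 * 8 * 9) = baseSizes#0 floordiv 504.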
Differential Revision: https://reviews.llvm.org/D133625
Added:
Modified:
mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
index 3cad0af9b9be2..16d17aa0b183b 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
@@ -43,7 +43,7 @@ namespace {
/// \endverbatim
///
/// In other words, get rid of the subview in that expression and canonicalize
-/// on its effects on the offset, the sizes, and the strides using affine apply.
+/// on its effects on the offset, the sizes, and the strides using affine.apply.
struct ExtractStridedMetadataOpSubviewFolder
: public OpRewritePattern<memref::ExtractStridedMetadataOp> {
public:
@@ -166,11 +166,275 @@ struct ExtractStridedMetadataOpSubviewFolder
return success();
}
};
+
+/// Compute the expanded sizes of the given \p expandShape for the
+/// \p groupId-th reassociation group.
+/// \p origSizes hold the sizes of the source shape as values.
+/// This is used to compute the new sizes in cases of dynamic shapes.
+///
+/// \post result.size() == expandShape.getReassociationIndices()[groupId].size()
+///
+/// TODO: Move this utility function directly within ExpandShapeOp. For now,
+/// this is not possible because this function uses the Affine dialect and the
+/// MemRef dialect cannot depend on the Affine dialect.
+static SmallVector<OpFoldResult>
+getExpandedSizes(memref::ExpandShapeOp expandShape, OpBuilder &builder,
+ ArrayRef<OpFoldResult> origSizes, unsigned groupId) {
+ SmallVector<int64_t, 2> reassocGroup =
+ expandShape.getReassociationIndices()[groupId];
+ assert(!reassocGroup.empty() &&
+ "Reassociation group should have at least one dimension");
+
+ unsigned groupSize = reassocGroup.size();
+ SmallVector<OpFoldResult> expandedSizes(groupSize);
+
+ uint64_t productOfAllStaticSizes = 1;
+ Optional<unsigned> dynSizeIdx;
+ MemRefType expandShapeType = expandShape.getResultType();
+
+ // Fill up all the statically known sizes.
+ for (unsigned i = 0; i < groupSize; ++i) {
+ uint64_t dimSize = expandShapeType.getDimSize(reassocGroup[i]);
+ if (ShapedType::isDynamic(dimSize)) {
+ assert(!dynSizeIdx && "There must be at most one dynamic size per group");
+ dynSizeIdx = i;
+ continue;
+ }
+ productOfAllStaticSizes *= dimSize;
+ expandedSizes[i] = builder.getIndexAttr(dimSize);
+ }
+
+ // Compute the dynamic size using the original size and all the other known
+ // static sizes:
+ // expandSize = origSize / productOfAllStaticSizes.
+ if (dynSizeIdx) {
+ AffineExpr s0 = builder.getAffineSymbolExpr(0);
+ expandedSizes[*dynSizeIdx] = makeComposedFoldedAffineApply(
+ builder, expandShape.getLoc(), s0.floorDiv(productOfAllStaticSizes),
+ origSizes[groupId]);
+ }
+
+ return expandedSizes;
+}
+
+/// Compute the expanded strides of the given \p expandShape for the
+/// \p groupId-th reassociation group.
+/// \p origStrides and \p origSizes hold respectively the strides and sizes
+/// of the source shape as values.
+/// This is used to compute the strides in cases of dynamic shapes and/or
+/// dynamic stride for this reassociation group.
+///
+/// \post result.size() == expandShape.getReassociationIndices()[groupId].size()
+///
+/// TODO: Move this utility function directly within ExpandShapeOp. For now,
+/// this is not possible because this function uses the Affine dialect and the
+/// MemRef dialect cannot depend on the Affine dialect.
+SmallVector<OpFoldResult> getExpandedStrides(memref::ExpandShapeOp expandShape,
+ OpBuilder &builder,
+ ArrayRef<OpFoldResult> origSizes,
+ ArrayRef<OpFoldResult> origStrides,
+ unsigned groupId) {
+ SmallVector<int64_t, 2> reassocGroup =
+ expandShape.getReassociationIndices()[groupId];
+ assert(!reassocGroup.empty() &&
+ "Reassociation group should have at least one dimension");
+
+ unsigned groupSize = reassocGroup.size();
+ MemRefType expandShapeType = expandShape.getResultType();
+
+ Optional<int64_t> dynSizeIdx;
+
+ // Fill up the expanded strides, with the information we can deduce from the
+ // resulting shape.
+ uint64_t currentStride = 1;
+ SmallVector<OpFoldResult> expandedStrides(groupSize);
+ for (int i = groupSize - 1; i >= 0; --i) {
+ expandedStrides[i] = builder.getIndexAttr(currentStride);
+ uint64_t dimSize = expandShapeType.getDimSize(reassocGroup[i]);
+ if (ShapedType::isDynamic(dimSize)) {
+ assert(!dynSizeIdx && "There must be at most one dynamic size per group");
+ dynSizeIdx = i;
+ continue;
+ }
+
+ currentStride *= dimSize;
+ }
+
+ // Collect the statically known information about the original stride.
+ Value source = expandShape.getSrc();
+ auto sourceType = source.getType().cast<MemRefType>();
+ SmallVector<int64_t> strides;
+ int64_t offset;
+ bool hasKnownStridesAndOffset =
+ succeeded(getStridesAndOffset(sourceType, strides, offset));
+ (void)hasKnownStridesAndOffset;
+ assert(hasKnownStridesAndOffset &&
+ "getStridesAndOffset must work on valid expand_shape");
+
+ OpFoldResult origStride =
+ ShapedType::isDynamicStrideOrOffset(strides[groupId])
+ ? origStrides[groupId]
+ : builder.getIndexAttr(strides[groupId]);
+
+ // Apply the original stride to all the strides.
+ int64_t doneStrideIdx = 0;
+ // If we saw a dynamic dimension, we need to fix-up all the strides up to
+ // that dimension with the dynamic size.
+ if (dynSizeIdx) {
+ int64_t productOfAllStaticSizes = currentStride;
+ assert(ShapedType::isDynamic(sourceType.getDimSize(groupId)) &&
+ "We shouldn't be able to change dynamicity");
+ OpFoldResult origSize = origSizes[groupId];
+
+ AffineExpr s0 = builder.getAffineSymbolExpr(0);
+ AffineExpr s1 = builder.getAffineSymbolExpr(1);
+ for (; doneStrideIdx < *dynSizeIdx; ++doneStrideIdx) {
+ int64_t baseExpandedStride = expandedStrides[doneStrideIdx]
+ .get<Attribute>()
+ .cast<IntegerAttr>()
+ .getInt();
+ expandedStrides[doneStrideIdx] = makeComposedFoldedAffineApply(
+ builder, expandShape.getLoc(),
+ (s0 * baseExpandedStride).floorDiv(productOfAllStaticSizes) * s1,
+ {origSize, origStride});
+ }
+ }
+
+ // Now apply the origStride to the remaining dimensions.
+ AffineExpr s0 = builder.getAffineSymbolExpr(0);
+ for (; doneStrideIdx < groupSize; ++doneStrideIdx) {
+ int64_t baseExpandedStride = expandedStrides[doneStrideIdx]
+ .get<Attribute>()
+ .cast<IntegerAttr>()
+ .getInt();
+ expandedStrides[doneStrideIdx] = makeComposedFoldedAffineApply(
+ builder, expandShape.getLoc(), s0 * baseExpandedStride, {origStride});
+ }
+
+ return expandedStrides;
+}
+
+/// Replace `baseBuffer, offset, sizes, strides =
+/// extract_strided_metadata(expand_shape(memref))`
+/// With
+///
+/// \verbatim
+/// baseBuffer, offset, baseSizes, baseStrides =
+/// extract_strided_metadata(memref)
+/// sizes#reassIdx =
+/// baseSizes#reassDim / product(expandShapeSizes#j,
+/// for j in group excluding reassIdx)
+/// strides#reassIdx =
+/// baseStrides#reassDim * product(expandShapeSizes#j, for j in
+/// reassIdx+1..reassIdx+group.size-1)
+/// \endverbatim
+///
+/// Where reassIdx is a reassociation index for the group at reassDim
+/// and expandShapeSizes#j is either:
+/// - The constant size at dimension j, derived directly from the result type of
+/// the expand_shape op, or
+/// - An affine expression: baseSizes#reassDim / product of all constant sizes
+/// in expandShapeSizes. (Remember expandShapeSizes has at most one dynamic
+/// element.)
+///
+/// Notice that `baseBuffer` and `offset` are unchanged.
+///
+/// In other words, get rid of the expand_shape in that expression and
+/// materialize its effects on the sizes and the strides using affine apply.
+struct ExtractStridedMetadataOpExpandShapeFolder
+ : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
+public:
+ using OpRewritePattern<memref::ExtractStridedMetadataOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
+ PatternRewriter &rewriter) const override {
+ auto expandShape = op.getSource().getDefiningOp<memref::ExpandShapeOp>();
+ if (!expandShape)
+ return failure();
+
+ // Build a plain extract_strided_metadata(memref) from
+ // extract_strided_metadata(expand_shape(memref)).
+ Location origLoc = op.getLoc();
+ IndexType indexType = rewriter.getIndexType();
+ Value source = expandShape.getSrc();
+ auto sourceType = source.getType().cast<MemRefType>();
+ unsigned sourceRank = sourceType.getRank();
+ SmallVector<Type> sizeStrideTypes(sourceRank, indexType);
+
+ auto newExtractStridedMetadata =
+ rewriter.create<memref::ExtractStridedMetadataOp>(
+ origLoc, op.getBaseBuffer().getType(), indexType, sizeStrideTypes,
+ sizeStrideTypes, source);
+
+ // Collect statically known information.
+ SmallVector<int64_t> strides;
+ int64_t offset;
+ bool hasKnownStridesAndOffset =
+ succeeded(getStridesAndOffset(sourceType, strides, offset));
+ (void)hasKnownStridesAndOffset;
+ assert(hasKnownStridesAndOffset &&
+ "getStridesAndOffset must work on valid expand_shape");
+ MemRefType expandShapeType = expandShape.getResultType();
+ unsigned expandShapeRank = expandShapeType.getRank();
+
+ // The result value will start with the base_buffer and offset.
+ unsigned baseIdxInResult = 2;
+ SmallVector<OpFoldResult> results(baseIdxInResult + expandShapeRank * 2);
+ results[0] = newExtractStridedMetadata.getBaseBuffer();
+ results[1] = ShapedType::isDynamicStrideOrOffset(offset)
+ ? getAsOpFoldResult(newExtractStridedMetadata.getOffset())
+ : rewriter.getIndexAttr(offset);
+
+ // Get the special case of 0-D out of the way.
+ if (sourceRank == 0) {
+ Value constantOne = getValueOrCreateConstantIndexOp(
+ rewriter, origLoc, rewriter.getIndexAttr(1));
+ SmallVector<Value> resultValues(baseIdxInResult + expandShapeRank * 2,
+ constantOne);
+ for (unsigned i = 0; i < baseIdxInResult; ++i)
+ resultValues[i] =
+ getValueOrCreateConstantIndexOp(rewriter, origLoc, results[i]);
+ rewriter.replaceOp(op, resultValues);
+ return success();
+ }
+
+ // Compute the expanded strides and sizes from the base strides and sizes.
+ SmallVector<OpFoldResult> origSizes =
+ getAsOpFoldResult(newExtractStridedMetadata.getSizes());
+ SmallVector<OpFoldResult> origStrides =
+ getAsOpFoldResult(newExtractStridedMetadata.getStrides());
+ unsigned idx = 0, endIdx = expandShape.getReassociationIndices().size();
+ for (; idx != endIdx; ++idx) {
+ SmallVector<OpFoldResult> expandedSizes =
+ getExpandedSizes(expandShape, rewriter, origSizes, /*groupId=*/idx);
+ SmallVector<OpFoldResult> expandedStrides = getExpandedStrides(
+ expandShape, rewriter, origSizes, origStrides, /*groupId=*/idx);
+
+ unsigned groupSize = expandShape.getReassociationIndices()[idx].size();
+ const unsigned sizeStartIdx = baseIdxInResult;
+ const unsigned strideStartIdx = sizeStartIdx + expandShapeRank;
+ for (unsigned i = 0; i < groupSize; ++i) {
+ results[sizeStartIdx + i] = expandedSizes[i];
+ results[strideStartIdx + i] = expandedStrides[i];
+ }
+ baseIdxInResult += groupSize;
+ }
+ assert(idx == sourceRank &&
+ "We should have visited all the input dimensions");
+ assert(baseIdxInResult == expandShapeRank + 2 &&
+ "We should have populated all the values");
+ rewriter.replaceOp(
+ op, getValueOrCreateConstantIndexOp(rewriter, origLoc, results));
+ return success();
+ }
+};
} // namespace
void memref::populateSimplifyExtractStridedMetadataOpPatterns(
RewritePatternSet &patterns) {
- patterns.add<ExtractStridedMetadataOpSubviewFolder>(patterns.getContext());
+ patterns.add<ExtractStridedMetadataOpSubviewFolder,
+ ExtractStridedMetadataOpExpandShapeFolder>(
+ patterns.getContext());
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir b/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir
index 0daeb4a23a1f7..3b2b00d2dc6fa 100644
--- a/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir
+++ b/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir
@@ -280,3 +280,231 @@ func.func @extract_strided_metadata_of_subview_all_dynamic(
return %base_buffer, %offset, %sizes#0, %sizes#1, %sizes#2, %strides#0, %strides#1, %strides#2 :
memref<f32>, index, index, index, index, index, index, index
}
+
+// -----
+
+// Check that we properly simplify extract_strided_metadata of expand_shape
+// into:
+// baseBuffer, baseOffset, baseSizes, baseStrides =
+// extract_strided_metadata(memref)
+// sizes#reassIdx =
+// baseSizes#reassDim / product(expandShapeSizes#j,
+// for j in group excluding reassIdx)
+// strides#reassIdx =
+// baseStrides#reassDim * product(expandShapeSizes#j, for j in
+// reassIdx+1..reassIdx+group.size)
+//
+// Here we have:
+// For the group applying to dim0:
+// size 0 = 3
+// size 1 = 5
+// size 2 = 2
+// stride 0 = baseStrides#0 * 5 * 2
+// = 4 * 5 * 2
+// = 40
+// stride 1 = baseStrides#0 * 2
+// = 4 * 2
+// = 8
+// stride 2 = baseStrides#0
+// = 4
+//
+// For the group applying to dim1:
+// size 3 = 2
+// size 4 = 2
+// stride 3 = baseStrides#1 * 2
+// = 1 * 2
+// = 2
+// stride 4 = baseStrides#1
+// = 1
+//
+// Base and offset are unchanged.
+//
+// CHECK-LABEL: func @extract_strided_metadata_of_expand_shape_all_static
+// CHECK-SAME: (%[[ARG:.*]]: memref<30x4xi16>)
+//
+// CHECK-DAG: %[[C40:.*]] = arith.constant 40 : index
+// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
+// CHECK-DAG: %[[C5:.*]] = arith.constant 5 : index
+// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+//
+// CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]] : memref<30x4xi16> -> memref<i16>, index, index, index, index, index
+//
+// CHECK: return %[[BASE]], %[[C0]], %[[C3]], %[[C5]], %[[C2]], %[[C2]], %[[C2]], %[[C40]], %[[C8]], %[[C4]], %[[C2]], %[[C1]] : memref<i16>, index, index, index, index, index, index, index, index, index, index, index
+func.func @extract_strided_metadata_of_expand_shape_all_static(
+ %arg : memref<30x4xi16>)
+ -> (memref<i16>, index,
+ index, index, index, index, index,
+ index, index, index, index, index) {
+
+ %expand_shape = memref.expand_shape %arg[[0, 1, 2], [3, 4]] :
+ memref<30x4xi16> into memref<3x5x2x2x2xi16>
+
+ %base, %offset, %sizes:5, %strides:5 = memref.extract_strided_metadata %expand_shape :
+ memref<3x5x2x2x2xi16>
+ -> memref<i16>, index,
+ index, index, index, index, index,
+ index, index, index, index, index
+
+ return %base, %offset,
+ %sizes#0, %sizes#1, %sizes#2, %sizes#3, %sizes#4,
+ %strides#0, %strides#1, %strides#2, %strides#3, %strides#4 :
+ memref<i16>, index,
+ index, index, index, index, index,
+ index, index, index, index, index
+}
+
+// -----
+
+// Check that we properly simplify extract_strided_metadata of expand_shape
+// when dynamic sizes, strides, and offsets are involved.
+// See extract_strided_metadata_of_expand_shape_all_static for an explanation
+// of the expansion.
+//
+// One of the important characteristics of this test is that the dynamic
+// dimensions produced by the expand_shape appear both in the first dimension
+// (for group 1) and in a non-first dimension (second dimension for group 2).
+// The idea is to make sure that:
+// 1. We properly account for dynamic shapes even when the strides are not
+// affected by them. (When the dynamic dimension is the first one.)
+// 2. We properly compute the strides affected by dynamic shapes. (When the
+// dynamic dimension is not the first one.)
+//
+// Here we have:
+// For the group applying to dim0:
+// size 0 = baseSizes#0 / (all static sizes in that group)
+// = baseSizes#0 / (7 * 8 * 9)
+// = baseSizes#0 / 504
+// size 1 = 7
+// size 2 = 8
+// size 3 = 9
+// stride 0 = baseStrides#0 * 7 * 8 * 9
+// = baseStrides#0 * 504
+// stride 1 = baseStrides#0 * 8 * 9
+// = baseStrides#0 * 72
+// stride 2 = baseStrides#0 * 9
+// stride 3 = baseStrides#0
+//
+// For the group applying to dim1:
+// size 4 = 10
+// size 5 = 2
+// size 6 = baseSizes#1 / (all static sizes in that group)
+// = baseSizes#1 / (10 * 2 * 3)
+// = baseSizes#1 / 60
+// size 7 = 3
+// stride 4 = baseStrides#1 * size 5 * size 6 * size 7
+// = baseStrides#1 * 2 * (baseSizes#1 / 60) * 3
+// = baseStrides#1 * (baseSizes#1 / 60) * 6
+// and since we know that baseSizes#1 is a multiple of 60:
+// = baseStrides#1 * (baseSizes#1 / 10)
+// stride 5 = baseStrides#1 * size 6 * size 7
+// = baseStrides#1 * (baseSizes#1 / 60) * 3
+// = baseStrides#1 * (baseSizes#1 / 20)
+// stride 6 = baseStrides#1 * size 7
+// = baseStrides#1 * 3
+// stride 7 = baseStrides#1
+//
+// Base and offset are unchanged.
+//
+// CHECK-DAG: #[[$DIM0_SIZE_MAP:.*]] = affine_map<()[s0] -> (s0 floordiv 504)>
+// CHECK-DAG: #[[$DIM6_SIZE_MAP:.*]] = affine_map<()[s0] -> (s0 floordiv 60)>
+//
+// CHECK-DAG: #[[$DIM0_STRIDE_MAP:.*]] = affine_map<()[s0] -> (s0 * 504)>
+// CHECK-DAG: #[[$DIM1_STRIDE_MAP:.*]] = affine_map<()[s0] -> (s0 * 72)>
+// CHECK-DAG: #[[$DIM2_STRIDE_MAP:.*]] = affine_map<()[s0] -> (s0 * 9)>
+// CHECK-DAG: #[[$DIM4_STRIDE_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 floordiv 10) * s1)>
+// CHECK-DAG: #[[$DIM5_STRIDE_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 floordiv 20) * s1)>
+// CHECK-DAG: #[[$DIM6_STRIDE_MAP:.*]] = affine_map<()[s0] -> (s0 * 3)>
+// CHECK-LABEL: func @extract_strided_metadata_of_expand_shape_all_dynamic
+// CHECK-SAME: (%[[ARG:.*]]: memref<?x?xf32,
+//
+// CHECK-DAG: %[[C10:.*]] = arith.constant 10 : index
+// CHECK-DAG: %[[C9:.*]] = arith.constant 9 : index
+// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
+// CHECK-DAG: %[[C7:.*]] = arith.constant 7 : index
+// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+//
+// CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]] : memref<?x?xf32, strided<[?, ?], offset: ?>> -> memref<f32>, index, index, index, index, index
+//
+// CHECK-DAG: %[[DYN_SIZE0:.*]] = affine.apply #[[$DIM0_SIZE_MAP]]()[%[[SIZES]]#0]
+// CHECK-DAG: %[[DYN_SIZE6:.*]] = affine.apply #[[$DIM6_SIZE_MAP]]()[%[[SIZES]]#1]
+// CHECK-DAG: %[[DYN_STRIDE0:.*]] = affine.apply #[[$DIM0_STRIDE_MAP]]()[%[[STRIDES]]#0]
+// CHECK-DAG: %[[DYN_STRIDE1:.*]] = affine.apply #[[$DIM1_STRIDE_MAP]]()[%[[STRIDES]]#0]
+// CHECK-DAG: %[[DYN_STRIDE2:.*]] = affine.apply #[[$DIM2_STRIDE_MAP]]()[%[[STRIDES]]#0]
+// CHECK-DAG: %[[DYN_STRIDE4:.*]] = affine.apply #[[$DIM4_STRIDE_MAP]]()[%[[SIZES]]#1, %[[STRIDES]]#1]
+// CHECK-DAG: %[[DYN_STRIDE5:.*]] = affine.apply #[[$DIM5_STRIDE_MAP]]()[%[[SIZES]]#1, %[[STRIDES]]#1]
+// CHECK-DAG: %[[DYN_STRIDE6:.*]] = affine.apply #[[$DIM6_STRIDE_MAP]]()[%[[STRIDES]]#1]
+
+// CHECK: return %[[BASE]], %[[OFFSET]], %[[DYN_SIZE0]], %[[C7]], %[[C8]], %[[C9]], %[[C10]], %[[C2]], %[[DYN_SIZE6]], %[[C3]], %[[DYN_STRIDE0]], %[[DYN_STRIDE1]], %[[DYN_STRIDE2]], %[[STRIDES]]#0, %[[DYN_STRIDE4]], %[[DYN_STRIDE5]], %[[DYN_STRIDE6]], %[[STRIDES]]#1 : memref<f32>, index, index, index, index, index, index, index, index, index, index, index, index, index
+func.func @extract_strided_metadata_of_expand_shape_all_dynamic(
+ %base: memref<?x?xf32, strided<[?,?], offset:?>>,
+ %offset0: index, %offset1: index, %offset2: index,
+ %size0: index, %size1: index, %size2: index,
+ %stride0: index, %stride1: index, %stride2: index)
+ -> (memref<f32>, index,
+ index, index, index, index, index, index, index, index,
+ index, index, index, index, index, index, index, index) {
+
+ %subview = memref.expand_shape %base[[0, 1, 2, 3],[4, 5, 6, 7]] :
+ memref<?x?xf32, strided<[?,?], offset: ?>> into
+ memref<?x7x8x9x10x2x?x3xf32, strided<[?, ?, ?, ?, ?, ?, ?, ?], offset: ?>>
+
+ %base_buffer, %offset, %sizes:8, %strides:8 = memref.extract_strided_metadata %subview :
+ memref<?x7x8x9x10x2x?x3xf32, strided<[?, ?, ?, ?, ?, ?, ?, ?], offset: ?>>
+ -> memref<f32>, index,
+ index, index, index, index, index, index, index, index,
+ index, index, index, index, index, index, index, index
+
+ return %base_buffer, %offset,
+ %sizes#0, %sizes#1, %sizes#2, %sizes#3, %sizes#4, %sizes#5, %sizes#6, %sizes#7,
+ %strides#0, %strides#1, %strides#2, %strides#3, %strides#4, %strides#5, %strides#6, %strides#7 :
+ memref<f32>, index,
+ index, index, index, index, index, index, index, index,
+ index, index, index, index, index, index, index, index
+}
+
+
+// -----
+
+// Check that we properly handle extract_strided_metadata of expand_shape for
+// a 0-D input.
+// The 0-D case is pretty boring:
+// all the expanded sizes are 1, likewise for the strides, and we keep the
+// original base and offset.
+// We still have a test for it because the input reassociation map of the
+// expand_shape is empty, so the handling of such a shape hits a corner
+// case.
+// CHECK-LABEL: func @extract_strided_metadata_of_expand_shape_all_static_0_rank
+// CHECK-SAME: (%[[ARG:.*]]: memref<i16, strided<[], offset: ?>>)
+//
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+//
+// CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]] = memref.extract_strided_metadata %[[ARG]] : memref<i16, strided<[], offset: ?>> -> memref<i16>, index
+//
+// CHECK: return %[[BASE]], %[[OFFSET]], %[[C1]], %[[C1]], %[[C1]], %[[C1]], %[[C1]], %[[C1]], %[[C1]], %[[C1]], %[[C1]], %[[C1]] : memref<i16>, index, index, index, index, index, index, index, index, index, index, index
+func.func @extract_strided_metadata_of_expand_shape_all_static_0_rank(
+ %arg : memref<i16, strided<[], offset: ?>>)
+ -> (memref<i16>, index,
+ index, index, index, index, index,
+ index, index, index, index, index) {
+
+ %expand_shape = memref.expand_shape %arg[] :
+ memref<i16, strided<[], offset: ?>> into memref<1x1x1x1x1xi16, strided<[1,1,1,1,1], offset: ?>>
+
+ %base, %offset, %sizes:5, %strides:5 = memref.extract_strided_metadata %expand_shape :
+ memref<1x1x1x1x1xi16, strided<[1,1,1,1,1], offset: ?>>
+ -> memref<i16>, index,
+ index, index, index, index, index,
+ index, index, index, index, index
+
+ return %base, %offset,
+ %sizes#0, %sizes#1, %sizes#2, %sizes#3, %sizes#4,
+ %strides#0, %strides#1, %strides#2, %strides#3, %strides#4 :
+ memref<i16>, index,
+ index, index, index, index, index,
+ index, index, index, index, index
+}
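
For readers who want to exercise these patterns outside of the test, here
is a minimal sketch of a driver (a hypothetical helper, not part of this
commit; it assumes the populate function is declared in
mlir/Dialect/MemRef/Transforms/Passes.h and relies on the standard greedy
rewrite driver):
```
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Hypothetical helper, not part of this commit: greedily applies the
// subview and expand_shape folders registered above until fixpoint.
static void simplifyExtractStridedMetadata(mlir::Operation *root) {
  mlir::RewritePatternSet patterns(root->getContext());
  mlir::memref::populateSimplifyExtractStridedMetadataOpPatterns(patterns);
  (void)mlir::applyPatternsAndFoldGreedily(root, std::move(patterns));
}
```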