[Mlir-commits] [mlir] d831568 - [mlir][MemRef] Simplify extract_strided_metadata(collapse_shape)

Quentin Colombet llvmlistbot at llvm.org
Fri Sep 30 10:08:41 PDT 2022


Author: Quentin Colombet
Date: 2022-09-30T16:54:56Z
New Revision: d8315681714222e83e32beba374a5ff976d90059

URL: https://github.com/llvm/llvm-project/commit/d8315681714222e83e32beba374a5ff976d90059
DIFF: https://github.com/llvm/llvm-project/commit/d8315681714222e83e32beba374a5ff976d90059.diff

LOG: [mlir][MemRef] Simplify extract_strided_metadata(collapse_shape)

The new pattern gets rid of the collapse_shape operation while
materializing its effects on the sizes and the strides of the base
object.

In other words, this simplification replaces:

```
baseBuffer, offset, sizes, strides =
    extract_strided_metadata(collapse_shape(memref))
```

With

```
baseBuffer, offset, baseSizes, baseStrides =
    extract_strided_metadata(memref)
for reassDim in {0 .. collapseRank - 1}
  sizes#reassDim = product(baseSizes#i for i in group[reassDim])
  strides#reassDim = baseStrides[group[reassDim].back()]
```

Note: baseBuffer and offset are unaffected by the collapse_shape
operation.

Differential Revision: https://reviews.llvm.org/D134826

Added: 
    

Modified: 
    mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
    mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
index 462950fd5268c..5a14578fa112a 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
@@ -173,6 +173,11 @@ struct ExtractStridedMetadataOpSubviewFolder
 /// \p origSizes hold the sizes of the source shape as values.
 /// This is used to compute the new sizes in cases of dynamic shapes.
 ///
+/// sizes#i =
+///     baseSizes#groupId / product(expandShapeSizes#j,
+///                                  for j in group excluding reassIdx#i)
+/// Where reassIdx#i is the reassociation index at index i in \p groupId.
+///
 /// \post result.size() == expandShape.getReassociationIndices()[groupId].size()
 ///
 /// TODO: Move this utility function directly within ExpandShapeOp. For now,
@@ -225,6 +230,18 @@ getExpandedSizes(memref::ExpandShapeOp expandShape, OpBuilder &builder,
 /// This is used to compute the strides in cases of dynamic shapes and/or
 /// dynamic stride for this reassociation group.
 ///
+/// strides#i =
+///     origStrides#reassDim * product(expandShapeSizes#j, for j in
+///                                    reassIdx#i+1..reassIdx#i+group.size-1)
+///
+/// Where reassIdx#i is the reassociation index at index i in \p groupId
+/// and expandShapeSizes#j is either:
+/// - The constant size at dimension j, derived directly from the result type of
+///   the expand_shape op, or
+/// - An affine expression: baseSizes#reassDim / product of all constant sizes
+///   in expandShapeSizes. (Remember expandShapeSizes has at most one dynamic
+///   element.)
+///
 /// \post result.size() == expandShape.getReassociationIndices()[groupId].size()
 ///
 /// TODO: Move this utility function directly within ExpandShapeOp. For now,
@@ -315,49 +332,162 @@ SmallVector<OpFoldResult> getExpandedStrides(memref::ExpandShapeOp expandShape,
   return expandedStrides;
 }
 
+/// Produce an OpFoldResult object with \p builder at \p loc representing
+/// `prod(valueOrConstant#i, for i in {indices})`,
+/// where valueOrConstant#i is maybeConstants[i] when \p isDynamic returns
+/// false for maybeConstants[i], and values[i] otherwise.
+///
+/// \pre for all index in indices: index < values.size()
+/// \pre for all index in indices: index < maybeConstants.size()
+static OpFoldResult
+getProductOfValues(ArrayRef<int64_t> indices, OpBuilder &builder, Location loc,
+                   ArrayRef<int64_t> maybeConstants,
+                   ArrayRef<OpFoldResult> values,
+                   llvm::function_ref<bool(int64_t)> isDynamic) {
+  // One affine symbol per index; static operands are passed as attributes so
+  // that makeComposedFoldedAffineApply can fold them into the expression.
+  AffineExpr productOfValues = builder.getAffineConstantExpr(1);
+  SmallVector<OpFoldResult> inputValues;
+  inputValues.reserve(indices.size());
+  for (int64_t srcIdx : indices) {
+    // The next symbol position is the number of operands collected so far.
+    productOfValues =
+        productOfValues * builder.getAffineSymbolExpr(inputValues.size());
+    int64_t maybeConstant = maybeConstants[srcIdx];
+    inputValues.push_back(isDynamic(maybeConstant)
+                              ? values[srcIdx]
+                              : builder.getIndexAttr(maybeConstant));
+  }
+
+  return makeComposedFoldedAffineApply(builder, loc, productOfValues,
+                                       inputValues);
+}
+
+/// Compute the collapsed size of the given \p collapseShape for the
+/// \p groupId-th reassociation group.
+/// \p origSizes hold the sizes of the source shape as values.
+/// This is used to compute the new sizes in cases of dynamic shapes.
+///
+/// Conceptually this helper function computes:
+/// `prod(origSizes#i, for i in {reassociationGroup[groupId]})`.
+///
+/// \post result.size() == 1, in other words, each group collapses to one
+/// dimension.
+///
+/// TODO: Move this utility function directly within CollapseShapeOp. For now,
+/// this is not possible because this function uses the Affine dialect and the
+/// MemRef dialect cannot depend on the Affine dialect.
+static SmallVector<OpFoldResult>
+getCollapsedSize(memref::CollapseShapeOp collapseShape, OpBuilder &builder,
+                 ArrayRef<OpFoldResult> origSizes, unsigned groupId) {
+  SmallVector<OpFoldResult> collapsedSize;
+
+  MemRefType collapseShapeType = collapseShape.getResultType();
+  // getDimSize returns int64_t: keep it signed for isDynamic/getIndexAttr.
+  int64_t size = collapseShapeType.getDimSize(groupId);
+  if (!ShapedType::isDynamic(size)) {
+    collapsedSize.push_back(builder.getIndexAttr(size));
+    return collapsedSize;
+  }
+
+  // We are dealing with a dynamic size.
+  // Build the affine expr of the product of the original sizes involved in that
+  // group.
+  Value source = collapseShape.getSrc();
+  auto sourceType = source.getType().cast<MemRefType>();
+
+  SmallVector<int64_t, 2> reassocGroup =
+      collapseShape.getReassociationIndices()[groupId];
+
+  collapsedSize.push_back(getProductOfValues(
+      reassocGroup, builder, collapseShape.getLoc(), sourceType.getShape(),
+      origSizes, ShapedType::isDynamic));
+
+  return collapsedSize;
+}
+
+/// Compute the collapsed stride of the given \p collapseShape for the
+/// \p groupId-th reassociation group.
+/// \p origStrides holds the strides of the source shape as values. The source
+/// sizes parameter is unused: it only exists to match the callback signature
+/// expected by ExtractStridedMetadataOpReshapeFolder.
+/// This is used to compute the strides in cases of dynamic shapes and/or
+/// dynamic stride for this reassociation group.
+///
+/// Conceptually this helper function returns the stride of the innermost
+/// dimension of that group in the original shape.
+///
+/// \post result.size() == 1, in other words, each group collapses to one
+/// dimension.
+static SmallVector<OpFoldResult>
+getCollapsedStride(memref::CollapseShapeOp collapseShape, OpBuilder &builder,
+                   ArrayRef<OpFoldResult> /*origSizes*/,
+                   ArrayRef<OpFoldResult> origStrides, unsigned groupId) {
+  SmallVector<int64_t, 2> reassocGroup =
+      collapseShape.getReassociationIndices()[groupId];
+  assert(!reassocGroup.empty() &&
+         "Reassociation group should have at least one dimension");
+
+  auto sourceType = collapseShape.getSrc().getType().cast<MemRefType>();
+  SmallVector<int64_t> strides;
+  int64_t offset;
+  bool hasKnownStridesAndOffset =
+      succeeded(getStridesAndOffset(sourceType, strides, offset));
+  (void)hasKnownStridesAndOffset;
+  assert(hasKnownStridesAndOffset &&
+         "getStridesAndOffset must work on valid collapse_shape");
+
+  // Reuse the innermost dimension's stride, statically folded when known.
+  int64_t innerMostDimForGroup = reassocGroup.back();
+  int64_t innerMostStrideForGroup = strides[innerMostDimForGroup];
+  SmallVector<OpFoldResult> collapsedStride;
+  collapsedStride.push_back(
+      ShapedType::isDynamicStrideOrOffset(innerMostStrideForGroup)
+          ? origStrides[innerMostDimForGroup]
+          : builder.getIndexAttr(innerMostStrideForGroup));
+  return collapsedStride;
+}
+
 /// Replace `baseBuffer, offset, sizes, strides =
-///              extract_strided_metadata(expand_shape(memref))`
+///              extract_strided_metadata(reshapeLike(memref))`
 /// With
 ///
 /// \verbatim
 /// baseBuffer, offset, baseSizes, baseStrides =
 ///     extract_strided_metadata(memref)
-/// sizes#reassIdx =
-///     baseSizes#reassDim / product(expandShapeSizes#j,
-///                                  for j in group excluding reassIdx)
-/// strides#reassIdx =
-///     baseStrides#reassDim * product(expandShapeSizes#j, for j in
-///                                    reassIdx+1..reassIdx+group.size-1)
+/// sizes = getReshapedSizes(reshapeLike)
+/// strides = getReshapedStrides(reshapeLike)
 /// \endverbatim
 ///
-/// Where reassIdx is a reassociation index for the group at reassDim
-/// and expandShapeSizes#j is either:
-/// - The constant size at dimension j, derived directly from the result type of
-///   the expand_shape op, or
-/// - An affine expression: baseSizes#reassDim / product of all constant sizes
-///   in expandShapeSizes. (Remember expandShapeSizes has at most one dynamic
-///   element.)
 ///
 /// Notice that `baseBuffer` and `offset` are unchanged.
 ///
-/// In other words, get rid of the expand_shape in that expression and
+/// In other words, get rid of the reshape-like op in that expression and
 /// materialize its effects on the sizes and the strides using affine apply.
-struct ExtractStridedMetadataOpExpandShapeFolder
+template <typename ReassociativeReshapeLikeOp,
+          SmallVector<OpFoldResult> (*getReshapedSizes)(
+              ReassociativeReshapeLikeOp, OpBuilder &,
+              ArrayRef<OpFoldResult> /*origSizes*/, unsigned /*groupId*/),
+          SmallVector<OpFoldResult> (*getReshapedStrides)(
+              ReassociativeReshapeLikeOp, OpBuilder &,
+              ArrayRef<OpFoldResult> /*origSizes*/,
+              ArrayRef<OpFoldResult> /*origStrides*/, unsigned /*groupId*/)>
+struct ExtractStridedMetadataOpReshapeFolder
     : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
 public:
   using OpRewritePattern<memref::ExtractStridedMetadataOp>::OpRewritePattern;
 
   LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
                                 PatternRewriter &rewriter) const override {
-    auto expandShape = op.getSource().getDefiningOp<memref::ExpandShapeOp>();
-    if (!expandShape)
+    auto reshape = op.getSource().getDefiningOp<ReassociativeReshapeLikeOp>();
+    if (!reshape)
       return failure();
 
     // Build a plain extract_strided_metadata(memref) from
-    // extract_strided_metadata(expand_shape(memref)).
+    // extract_strided_metadata(reassociative_reshape_like(memref)).
     Location origLoc = op.getLoc();
     IndexType indexType = rewriter.getIndexType();
-    Value source = expandShape.getSrc();
+    Value source = reshape.getSrc();
     auto sourceType = source.getType().cast<MemRefType>();
     unsigned sourceRank = sourceType.getRank();
     SmallVector<Type> sizeStrideTypes(sourceRank, indexType);
@@ -374,13 +504,13 @@ struct ExtractStridedMetadataOpExpandShapeFolder
         succeeded(getStridesAndOffset(sourceType, strides, offset));
     (void)hasKnownStridesAndOffset;
     assert(hasKnownStridesAndOffset &&
-           "getStridesAndOffset must work on valid expand_shape");
-    MemRefType expandShapeType = expandShape.getResultType();
-    unsigned expandShapeRank = expandShapeType.getRank();
+           "getStridesAndOffset must work on valid reassociative_reshape_like");
+    MemRefType reshapeType = reshape.getResultType();
+    unsigned reshapeRank = reshapeType.getRank();
 
     // The result value will start with the base_buffer and offset.
     unsigned baseIdxInResult = 2;
-    SmallVector<OpFoldResult> results(baseIdxInResult + expandShapeRank * 2);
+    SmallVector<OpFoldResult> results(baseIdxInResult + reshapeRank * 2);
     results[0] = newExtractStridedMetadata.getBaseBuffer();
     results[1] = ShapedType::isDynamicStrideOrOffset(offset)
                      ? getAsOpFoldResult(newExtractStridedMetadata.getOffset())
@@ -390,7 +520,7 @@ struct ExtractStridedMetadataOpExpandShapeFolder
     if (sourceRank == 0) {
       Value constantOne = getValueOrCreateConstantIndexOp(
           rewriter, origLoc, rewriter.getIndexAttr(1));
-      SmallVector<Value> resultValues(baseIdxInResult + expandShapeRank * 2,
+      SmallVector<Value> resultValues(baseIdxInResult + reshapeRank * 2,
                                       constantOne);
       for (unsigned i = 0; i < baseIdxInResult; ++i)
         resultValues[i] =
@@ -399,30 +529,31 @@ struct ExtractStridedMetadataOpExpandShapeFolder
       return success();
     }
 
-    // Compute the expanded strides and sizes from the base strides and sizes.
+    // Compute the reshaped strides and sizes from the base strides and sizes.
     SmallVector<OpFoldResult> origSizes =
         getAsOpFoldResult(newExtractStridedMetadata.getSizes());
     SmallVector<OpFoldResult> origStrides =
         getAsOpFoldResult(newExtractStridedMetadata.getStrides());
-    unsigned idx = 0, endIdx = expandShape.getReassociationIndices().size();
+    unsigned idx = 0, endIdx = reshape.getReassociationIndices().size();
     for (; idx != endIdx; ++idx) {
-      SmallVector<OpFoldResult> expandedSizes =
-          getExpandedSizes(expandShape, rewriter, origSizes, /*groupId=*/idx);
-      SmallVector<OpFoldResult> expandedStrides = getExpandedStrides(
-          expandShape, rewriter, origSizes, origStrides, /*groupId=*/idx);
+      SmallVector<OpFoldResult> reshapedSizes =
+          getReshapedSizes(reshape, rewriter, origSizes, /*groupId=*/idx);
+      SmallVector<OpFoldResult> reshapedStrides = getReshapedStrides(
+          reshape, rewriter, origSizes, origStrides, /*groupId=*/idx);
 
-      unsigned groupSize = expandShape.getReassociationIndices()[idx].size();
+      unsigned groupSize = reshapedSizes.size();
       const unsigned sizeStartIdx = baseIdxInResult;
-      const unsigned strideStartIdx = sizeStartIdx + expandShapeRank;
+      const unsigned strideStartIdx = sizeStartIdx + reshapeRank;
       for (unsigned i = 0; i < groupSize; ++i) {
-        results[sizeStartIdx + i] = expandedSizes[i];
-        results[strideStartIdx + i] = expandedStrides[i];
+        results[sizeStartIdx + i] = reshapedSizes[i];
+        results[strideStartIdx + i] = reshapedStrides[i];
       }
       baseIdxInResult += groupSize;
     }
-    assert(idx == sourceRank &&
+    assert(((isa<memref::ExpandShapeOp>(reshape) && idx == sourceRank) ||
+            (isa<memref::CollapseShapeOp>(reshape) && idx == reshapeRank)) &&
            "We should have visited all the input dimensions");
-    assert(baseIdxInResult == expandShapeRank + 2 &&
+    assert(baseIdxInResult == reshapeRank + 2 &&
            "We should have populated all the values");
     rewriter.replaceOp(
         op, getValueOrCreateConstantIndexOp(rewriter, origLoc, results));
@@ -599,12 +730,17 @@ class RewriteExtractAlignedPointerAsIndexOfViewLikeOp
 
 void memref::populateSimplifyExtractStridedMetadataOpPatterns(
     RewritePatternSet &patterns) {
-  patterns.add<ExtractStridedMetadataOpSubviewFolder,
-               ExtractStridedMetadataOpExpandShapeFolder, ForwardStaticMetadata,
-               ExtractStridedMetadataOpAllocFolder<memref::AllocOp>,
-               ExtractStridedMetadataOpAllocFolder<memref::AllocaOp>,
-               RewriteExtractAlignedPointerAsIndexOfViewLikeOp>(
-      patterns.getContext());
+  // One reshape folder instantiation per reassociative reshape-like op.
+  using ExpandShapeFolder = ExtractStridedMetadataOpReshapeFolder<
+      memref::ExpandShapeOp, getExpandedSizes, getExpandedStrides>;
+  using CollapseShapeFolder = ExtractStridedMetadataOpReshapeFolder<
+      memref::CollapseShapeOp, getCollapsedSize, getCollapsedStride>;
+  patterns.add<ExtractStridedMetadataOpSubviewFolder, ExpandShapeFolder,
+               CollapseShapeFolder, ForwardStaticMetadata,
+               ExtractStridedMetadataOpAllocFolder<memref::AllocOp>,
+               ExtractStridedMetadataOpAllocFolder<memref::AllocaOp>,
+               RewriteExtractAlignedPointerAsIndexOfViewLikeOp>(
+      patterns.getContext());
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir b/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir
index 2862ef96a53c8..616b835842910 100644
--- a/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir
+++ b/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir
@@ -762,3 +762,82 @@ func.func @extract_aligned_pointer_as_index(%arg0: memref<f32>) -> index {
   %r = memref.extract_aligned_pointer_as_index %arg0: memref<f32> -> index
   return %r : index
 }
+
+// -----
+
+// Check that we simplify extract_strided_metadata of collapse_shape.
+//
+// We collapse ?x?x4x?x6x7xi32 with the groups [0][1,2,3][4,5]
+// Size 0 = origSize0
+// Size 1 = origSize1 * origSize2 * origSize3
+//        = origSize1 * 4 * origSize3
+// Size 2 = origSize4 * origSize5
+//        = 6 * 7
+//        = 42
+// Stride 0 = origStride0
+// Stride 1 = origStride3 (orig stride of the innermost dimension)
+//          = 42
+// Stride 2 = origStride5
+//          = 1
+//
+//   CHECK-DAG: #[[$SIZE1_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 * s1) * 4)>
+// CHECK-LABEL: func @extract_strided_metadata_of_collapse(
+//  CHECK-SAME: %[[ARG:.*]]: memref<?x?x4x?x6x7xi32>)
+//
+//   CHECK-DAG: %[[C42:.*]] = arith.constant 42 : index
+//   CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+//   CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+//
+//   CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:6, %[[STRIDES:.*]]:6 = memref.extract_strided_metadata %[[ARG]] : memref<?x?x4x?x6x7xi32>
+//
+//   CHECK-DAG: %[[DYN_SIZE1:.*]] = affine.apply #[[$SIZE1_MAP]]()[%[[SIZES]]#1, %[[SIZES]]#3]
+//
+//       CHECK: return %[[BASE]], %[[C0]], %[[SIZES]]#0, %[[DYN_SIZE1]], %[[C42]], %[[STRIDES]]#0, %[[C42]], %[[C1]]
+func.func @extract_strided_metadata_of_collapse(%arg : memref<?x?x4x?x6x7xi32>)
+  -> (memref<i32>, index,
+      index, index, index,
+      index, index, index) {
+
+  %collapsed_view = memref.collapse_shape %arg [[0], [1, 2, 3], [4, 5]] :
+    memref<?x?x4x?x6x7xi32> into memref<?x?x42xi32>
+
+  %base, %offset, %sizes:3, %strides:3 =
+    memref.extract_strided_metadata %collapsed_view : memref<?x?x42xi32>
+    -> memref<i32>, index,
+       index, index, index,
+       index, index, index
+
+  return %base, %offset,
+    %sizes#0, %sizes#1, %sizes#2,
+    %strides#0, %strides#1, %strides#2 :
+      memref<i32>, index,
+      index, index, index,
+      index, index, index
+
+}
+
+// -----
+
+// Check that we simplify extract_strided_metadata of collapse_shape to
+// a rank-0 shape: only the base buffer and the offset remain as results.
+// CHECK-LABEL: func @extract_strided_metadata_of_collapse_to_rank0(
+//  CHECK-SAME: %[[ARG:.*]]: memref<1x1x1x1x1x1xi32>)
+//
+//   CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// The rank-6 source still yields 6 sizes and 6 strides from its metadata.
+//   CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:6, %[[STRIDES:.*]]:6 = memref.extract_strided_metadata %[[ARG]] : memref<1x1x1x1x1x1xi32>
+// The static source offset is materialized as a constant.
+//       CHECK: return %[[BASE]], %[[C0]]
+func.func @extract_strided_metadata_of_collapse_to_rank0(%arg : memref<1x1x1x1x1x1xi32>)
+  -> (memref<i32>, index) {
+
+  %collapsed_view = memref.collapse_shape %arg [] :
+    memref<1x1x1x1x1x1xi32> into memref<i32>
+
+  %base, %offset =
+    memref.extract_strided_metadata %collapsed_view : memref<i32>
+    -> memref<i32>, index
+
+  return %base, %offset :
+      memref<i32>, index
+}


        


More information about the Mlir-commits mailing list