[Mlir-commits] [mlir] d0aeb74 - [mlir][MemRef] Simplify extract_strided_metadata(expand_shape)

Quentin Colombet llvmlistbot at llvm.org
Thu Sep 22 12:14:41 PDT 2022


Author: Quentin Colombet
Date: 2022-09-22T19:07:09Z
New Revision: d0aeb74e8869e5db23c079b98c5e1f325aeeeefe

URL: https://github.com/llvm/llvm-project/commit/d0aeb74e8869e5db23c079b98c5e1f325aeeeefe
DIFF: https://github.com/llvm/llvm-project/commit/d0aeb74e8869e5db23c079b98c5e1f325aeeeefe.diff

LOG: [mlir][MemRef] Simplify extract_strided_metadata(expand_shape)

Add a pattern to the pass that simplifies
extract_strided_metadata(other_op(memref)).

The new pattern gets rid of the expand_shape operation while
materializing its effects on the sizes and the strides of
the base object.

In other words, this simplification replaces:
```
baseBuffer, offset, sizes, strides =
             extract_strided_metadata(expand_shape(memref))
```

With

```
baseBuffer, offset, baseSizes, baseStrides =
    extract_strided_metadata(memref)
sizes#reassIdx =
    baseSizes#reassDim / product(expandShapeSizes#j,
                                 for j in group excluding
                                   reassIdx)
strides#reassIdx =
    baseStrides#reassDim * product(expandShapeSizes#j,
                                   for j in group with j > reassIdx)
```

Where `reassIdx` is a reassociation index for the group at
`reassDim` and `expandShapeSizes#j` is either:
- The constant size at dimension j, derived directly from the
  result type of the expand_shape op, or
- An affine expression: baseSizes#reassDim / product of all
  constant sizes in the group (a reassociation group has at most
  one dynamic size).

Note: baseBuffer and offset are unaffected by the expand_shape
operation.

Differential Revision: https://reviews.llvm.org/D133625

Added: 
    

Modified: 
    mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
    mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
index 3cad0af9b9be2..16d17aa0b183b 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
@@ -43,7 +43,7 @@ namespace {
 /// \endverbatim
 ///
 /// In other words, get rid of the subview in that expression and canonicalize
-/// on its effects on the offset, the sizes, and the strides using affine apply.
+/// on its effects on the offset, the sizes, and the strides using affine.apply.
 struct ExtractStridedMetadataOpSubviewFolder
     : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
 public:
@@ -166,11 +166,275 @@ struct ExtractStridedMetadataOpSubviewFolder
     return success();
   }
 };
+
+/// Compute the expanded sizes of the given \p expandShape for the
+/// \p groupId-th reassociation group.
+/// \p origSizes holds the sizes of the source shape as values, one entry per
+/// source dimension; only origSizes[groupId] is read, for the dynamic size.
+///
+/// \post result.size() == expandShape.getReassociationIndices()[groupId].size()
+///
+/// TODO: Move this utility function directly within ExpandShapeOp. For now,
+/// this is not possible because this function uses the Affine dialect and the
+/// MemRef dialect cannot depend on the Affine dialect.
+static SmallVector<OpFoldResult>
+getExpandedSizes(memref::ExpandShapeOp expandShape, OpBuilder &builder,
+                 ArrayRef<OpFoldResult> origSizes, unsigned groupId) {
+  SmallVector<int64_t, 2> reassocGroup =
+      expandShape.getReassociationIndices()[groupId];
+  assert(!reassocGroup.empty() &&
+         "Reassociation group should have at least one dimension");
+
+  unsigned groupSize = reassocGroup.size();
+  SmallVector<OpFoldResult> expandedSizes(groupSize);
+
+  uint64_t productOfAllStaticSizes = 1;
+  Optional<unsigned> dynSizeIdx;
+  MemRefType expandShapeType = expandShape.getResultType();
+
+  // Fill in the static sizes and remember the (single) dynamic one, if any.
+  for (unsigned i = 0; i < groupSize; ++i) {
+    int64_t dimSize = expandShapeType.getDimSize(reassocGroup[i]);
+    if (ShapedType::isDynamic(dimSize)) {
+      assert(!dynSizeIdx && "There must be at most one dynamic size per group");
+      dynSizeIdx = i;
+      continue;
+    }
+    productOfAllStaticSizes *= dimSize;
+    expandedSizes[i] = builder.getIndexAttr(dimSize);
+  }
+
+  // Compute the dynamic size using the original size and all the other known
+  // static sizes:
+  // expandSize = origSize floordiv productOfAllStaticSizes.
+  if (dynSizeIdx) {
+    AffineExpr s0 = builder.getAffineSymbolExpr(0);
+    expandedSizes[*dynSizeIdx] = makeComposedFoldedAffineApply(
+        builder, expandShape.getLoc(), s0.floorDiv(productOfAllStaticSizes),
+        origSizes[groupId]);
+  }
+
+  return expandedSizes;
+}
+
+/// Compute the expanded strides of the given \p expandShape for the
+/// \p groupId-th reassociation group.
+/// \p origStrides and \p origSizes hold respectively the strides and sizes
+/// of the source shape as values; only their \p groupId-th entries are read.
+/// Static sizes come from the result type and the static source stride from
+/// the source type; the value operands are only needed when this group has a
+/// dynamic size and/or the source stride of this group is dynamic.
+///
+/// \post result.size() == expandShape.getReassociationIndices()[groupId].size()
+///
+/// TODO: Move this utility function directly within ExpandShapeOp. For now,
+/// this is not possible because this function uses the Affine dialect and the
+/// MemRef dialect cannot depend on the Affine dialect.
+static SmallVector<OpFoldResult>
+getExpandedStrides(memref::ExpandShapeOp expandShape, OpBuilder &builder,
+                   ArrayRef<OpFoldResult> origSizes,
+                   ArrayRef<OpFoldResult> origStrides, unsigned groupId) {
+  SmallVector<int64_t, 2> reassocGroup =
+      expandShape.getReassociationIndices()[groupId];
+  assert(!reassocGroup.empty() &&
+         "Reassociation group should have at least one dimension");
+
+  unsigned groupSize = reassocGroup.size();
+  MemRefType expandShapeType = expandShape.getResultType();
+
+  Optional<int64_t> dynSizeIdx;
+
+  // Walk the group right to left: each stride starts as the product of the
+  // static sizes after it (in units of the source stride).
+  uint64_t currentStride = 1;
+  SmallVector<OpFoldResult> expandedStrides(groupSize);
+  for (int i = groupSize - 1; i >= 0; --i) {
+    expandedStrides[i] = builder.getIndexAttr(currentStride);
+    int64_t dimSize = expandShapeType.getDimSize(reassocGroup[i]);
+    if (ShapedType::isDynamic(dimSize)) {
+      assert(!dynSizeIdx && "There must be at most one dynamic size per group");
+      dynSizeIdx = i;
+      continue;
+    }
+
+    currentStride *= dimSize;
+  }
+
+  // Collect the statically known information about the original stride.
+  Value source = expandShape.getSrc();
+  auto sourceType = source.getType().cast<MemRefType>();
+  SmallVector<int64_t> strides;
+  int64_t offset;
+  bool hasKnownStridesAndOffset =
+      succeeded(getStridesAndOffset(sourceType, strides, offset));
+  (void)hasKnownStridesAndOffset;
+  assert(hasKnownStridesAndOffset &&
+         "getStridesAndOffset must work on valid expand_shape");
+
+  OpFoldResult origStride =
+      ShapedType::isDynamicStrideOrOffset(strides[groupId])
+          ? origStrides[groupId]
+          : builder.getIndexAttr(strides[groupId]);
+
+  // Apply the original stride to all the strides.
+  int64_t doneStrideIdx = 0;
+  // If we saw a dynamic dimension, the strides of the dimensions before it
+  // also include the dynamic size, origSize / productOfAllStaticSizes.
+  if (dynSizeIdx) {
+    int64_t productOfAllStaticSizes = currentStride;
+    assert(ShapedType::isDynamic(sourceType.getDimSize(groupId)) &&
+           "We shouldn't be able to change dynamicity");
+    OpFoldResult origSize = origSizes[groupId];
+
+    AffineExpr s0 = builder.getAffineSymbolExpr(0);
+    AffineExpr s1 = builder.getAffineSymbolExpr(1);
+    for (; doneStrideIdx < *dynSizeIdx; ++doneStrideIdx) {
+      int64_t baseExpandedStride = expandedStrides[doneStrideIdx]
+                                       .get<Attribute>()
+                                       .cast<IntegerAttr>()
+                                       .getInt();
+      expandedStrides[doneStrideIdx] = makeComposedFoldedAffineApply(
+          builder, expandShape.getLoc(),
+          (s0 * baseExpandedStride).floorDiv(productOfAllStaticSizes) * s1,
+          {origSize, origStride});
+    }
+  }
+
+  // Now apply the origStride to the remaining dimensions.
+  AffineExpr s0 = builder.getAffineSymbolExpr(0);
+  for (; doneStrideIdx < groupSize; ++doneStrideIdx) {
+    int64_t baseExpandedStride = expandedStrides[doneStrideIdx]
+                                     .get<Attribute>()
+                                     .cast<IntegerAttr>()
+                                     .getInt();
+    expandedStrides[doneStrideIdx] = makeComposedFoldedAffineApply(
+        builder, expandShape.getLoc(), s0 * baseExpandedStride, {origStride});
+  }
+
+  return expandedStrides;
+}
+
+/// Replace `baseBuffer, offset, sizes, strides =
+///              extract_strided_metadata(expand_shape(memref))`
+/// With
+///
+/// \verbatim
+/// baseBuffer, offset, baseSizes, baseStrides =
+///     extract_strided_metadata(memref)
+/// sizes#reassIdx =
+///     baseSizes#reassDim / product(expandShapeSizes#j,
+///                                  for j in group excluding reassIdx)
+/// strides#reassIdx =
+///     baseStrides#reassDim * product(expandShapeSizes#j,
+///                                    for j in group with j > reassIdx)
+/// \endverbatim
+///
+/// Where reassIdx is a reassociation index for the group at reassDim
+/// and expandShapeSizes#j is either:
+/// - The constant size at dimension j, derived directly from the result type of
+///   the expand_shape op, or
+/// - An affine expression: baseSizes#reassDim / product of all constant sizes
+///   in the group. (Remember each reassociation group has at most one dynamic
+///   element.)
+///
+/// Notice that `baseBuffer` and `offset` are unchanged.
+///
+/// In other words, get rid of the expand_shape in that expression and
+/// materialize its effects on the sizes and the strides using affine.apply.
+struct ExtractStridedMetadataOpExpandShapeFolder
+    : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
+public:
+  using OpRewritePattern<memref::ExtractStridedMetadataOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
+                                PatternRewriter &rewriter) const override {
+    auto expandShape = op.getSource().getDefiningOp<memref::ExpandShapeOp>();
+    if (!expandShape)
+      return failure();
+
+    // Build a plain extract_strided_metadata(memref) from
+    // extract_strided_metadata(expand_shape(memref)).
+    Location origLoc = op.getLoc();
+    IndexType indexType = rewriter.getIndexType();
+    Value source = expandShape.getSrc();
+    auto sourceType = source.getType().cast<MemRefType>();
+    unsigned sourceRank = sourceType.getRank();
+    SmallVector<Type> sizeStrideTypes(sourceRank, indexType);
+
+    auto newExtractStridedMetadata =
+        rewriter.create<memref::ExtractStridedMetadataOp>(
+            origLoc, op.getBaseBuffer().getType(), indexType, sizeStrideTypes,
+            sizeStrideTypes, source);
+
+    // Collect the statically known offset and strides of the source.
+    SmallVector<int64_t> strides;
+    int64_t offset;
+    bool hasKnownStridesAndOffset =
+        succeeded(getStridesAndOffset(sourceType, strides, offset));
+    (void)hasKnownStridesAndOffset;
+    assert(hasKnownStridesAndOffset &&
+           "getStridesAndOffset must work on valid expand_shape");
+    MemRefType expandShapeType = expandShape.getResultType();
+    unsigned expandShapeRank = expandShapeType.getRank();
+
+    // Results layout: [base buffer, offset, expanded sizes..., strides...].
+    unsigned baseIdxInResult = 2;
+    SmallVector<OpFoldResult> results(baseIdxInResult + expandShapeRank * 2);
+    results[0] = newExtractStridedMetadata.getBaseBuffer();
+    results[1] = ShapedType::isDynamicStrideOrOffset(offset)
+                     ? getAsOpFoldResult(newExtractStridedMetadata.getOffset())
+                     : rewriter.getIndexAttr(offset);
+
+    // 0-D source: no reassociation groups; every size and stride is 1.
+    if (sourceRank == 0) {
+      Value constantOne = getValueOrCreateConstantIndexOp(
+          rewriter, origLoc, rewriter.getIndexAttr(1));
+      SmallVector<Value> resultValues(baseIdxInResult + expandShapeRank * 2,
+                                      constantOne);
+      for (unsigned i = 0; i < baseIdxInResult; ++i)
+        resultValues[i] =
+            getValueOrCreateConstantIndexOp(rewriter, origLoc, results[i]);
+      rewriter.replaceOp(op, resultValues);
+      return success();
+    }
+
+    // Compute the expanded strides and sizes from the base strides and sizes.
+    SmallVector<OpFoldResult> origSizes =
+        getAsOpFoldResult(newExtractStridedMetadata.getSizes());
+    SmallVector<OpFoldResult> origStrides =
+        getAsOpFoldResult(newExtractStridedMetadata.getStrides());
+    unsigned idx = 0, endIdx = expandShape.getReassociationIndices().size();
+    for (; idx != endIdx; ++idx) {
+      SmallVector<OpFoldResult> expandedSizes =
+          getExpandedSizes(expandShape, rewriter, origSizes, /*groupId=*/idx);
+      SmallVector<OpFoldResult> expandedStrides = getExpandedStrides(
+          expandShape, rewriter, origSizes, origStrides, /*groupId=*/idx);
+
+      unsigned groupSize = expandShape.getReassociationIndices()[idx].size();
+      const unsigned sizeStartIdx = baseIdxInResult;
+      const unsigned strideStartIdx = sizeStartIdx + expandShapeRank;
+      for (unsigned i = 0; i < groupSize; ++i) {
+        results[sizeStartIdx + i] = expandedSizes[i];
+        results[strideStartIdx + i] = expandedStrides[i];
+      }
+      baseIdxInResult += groupSize;
+    }
+    assert(idx == sourceRank &&
+           "We should have visited all the input dimensions");
+    assert(baseIdxInResult == expandShapeRank + 2 &&
+           "We should have populated all the values");
+    rewriter.replaceOp(
+        op, getValueOrCreateConstantIndexOp(rewriter, origLoc, results));
+    return success();
+  }
+};
 } // namespace
 
 void memref::populateSimplifyExtractStridedMetadataOpPatterns(
     RewritePatternSet &patterns) {
-  patterns.add<ExtractStridedMetadataOpSubviewFolder>(patterns.getContext());
+  patterns.add<ExtractStridedMetadataOpSubviewFolder,
+               ExtractStridedMetadataOpExpandShapeFolder>(
+      patterns.getContext());
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir b/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir
index 0daeb4a23a1f7..3b2b00d2dc6fa 100644
--- a/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir
+++ b/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir
@@ -280,3 +280,231 @@ func.func @extract_strided_metadata_of_subview_all_dynamic(
   return %base_buffer, %offset, %sizes#0, %sizes#1, %sizes#2, %strides#0, %strides#1, %strides#2 :
     memref<f32>, index, index, index, index, index, index, index
 }
+
+// -----
+
+// Check that we properly simplify extract_strided_metadata of expand_shape
+// into:
+// baseBuffer, baseOffset, baseSizes, baseStrides =
+//     extract_strided_metadata(memref)
+// sizes#reassIdx =
+//     baseSizes#reassDim / product(expandShapeSizes#j,
+//                                  for j in group excluding reassIdx)
+// strides#reassIdx =
+//     baseStrides#reassDim * product(expandShapeSizes#j,
+//                                    for j in group with j > reassIdx)
+//
+// Here we have:
+// For the group applying to dim0:
+// size 0 = 3
+// size 1 = 5
+// size 2 = 2
+// stride 0 = baseStrides#0 * 5 * 2
+//          = 4 * 5 * 2
+//          = 40
+// stride 1 = baseStrides#0 * 2
+//          = 4 * 2
+//          = 8
+// stride 2 = baseStrides#0
+//          = 4
+//
+// For the group applying to dim1:
+// size 3 = 2
+// size 4 = 2
+// stride 3 = baseStrides#1 * 2
+//          = 1 * 2
+//          = 2
+// stride 4 = baseStrides#1
+//          = 1
+//
+// Base and offset are unchanged.
+//
+// CHECK-LABEL: func @extract_strided_metadata_of_expand_shape_all_static
+//  CHECK-SAME: (%[[ARG:.*]]: memref<30x4xi16>)
+//
+//   CHECK-DAG: %[[C40:.*]] = arith.constant 40 : index
+//   CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
+//   CHECK-DAG: %[[C5:.*]] = arith.constant 5 : index
+//   CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+//   CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
+//   CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+//   CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+//   CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+//
+//   CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]] : memref<30x4xi16> -> memref<i16>, index, index, index, index, index
+// The offset and every expanded size and stride fold to constants.
+//   CHECK: return %[[BASE]], %[[C0]], %[[C3]], %[[C5]], %[[C2]], %[[C2]], %[[C2]], %[[C40]], %[[C8]], %[[C4]], %[[C2]], %[[C1]] : memref<i16>, index, index, index, index, index, index, index, index, index, index, index
+func.func @extract_strided_metadata_of_expand_shape_all_static(
+    %arg : memref<30x4xi16>)
+    -> (memref<i16>, index,
+       index, index, index, index, index,
+       index, index, index, index, index) {
+
+  %expand_shape = memref.expand_shape %arg[[0, 1, 2], [3, 4]] :
+    memref<30x4xi16> into memref<3x5x2x2x2xi16>
+
+  %base, %offset, %sizes:5, %strides:5 = memref.extract_strided_metadata %expand_shape :
+    memref<3x5x2x2x2xi16>
+    -> memref<i16>, index,
+       index, index, index, index, index,
+       index, index, index, index, index
+
+  return %base, %offset,
+    %sizes#0, %sizes#1, %sizes#2, %sizes#3, %sizes#4,
+    %strides#0, %strides#1, %strides#2, %strides#3, %strides#4 :
+      memref<i16>, index,
+      index, index, index, index, index,
+      index, index, index, index, index
+}
+
+// -----
+
+// Check that we properly simplify extract_strided_metadata of expand_shape
+// when dynamic sizes, strides, and offsets are involved.
+// See extract_strided_metadata_of_expand_shape_all_static for an explanation
+// of the expansion.
+//
+// One of the important characteristics of this test is that the dynamic
+// dimensions produced by the expand_shape appear both as the first dimension
+// of group 1 and as a non-first dimension of group 2 (its third dimension).
+// The idea is to make sure that:
+// 1. We properly account for dynamic shapes even when the strides are not
+//    affected by them. (When the dynamic dimension is the first one.)
+// 2. We properly compute the strides affected by dynamic shapes. (When the
+//    dynamic dimension is not the first one.)
+//
+// Here we have:
+// For the group applying to dim0:
+// size 0 = baseSizes#0 / (all static sizes in that group)
+//        = baseSizes#0 / (7 * 8 * 9)
+//        = baseSizes#0 / 504
+// size 1 = 7
+// size 2 = 8
+// size 3 = 9
+// stride 0 = baseStrides#0 * 7 * 8 * 9
+//          = baseStrides#0 * 504
+// stride 1 = baseStrides#0 * 8 * 9
+//          = baseStrides#0 * 72
+// stride 2 = baseStrides#0 * 9
+// stride 3 = baseStrides#0
+//
+// For the group applying to dim1:
+// size 4 = 10
+// size 5 = 2
+// size 6 = baseSizes#1 / (all static sizes in that group)
+//        = baseSizes#1 / (10 * 2 * 3)
+//        = baseSizes#1 / 60
+// size 7 = 3
+// stride 4 = baseStrides#1 * size 5 * size 6 * size 7
+//          = baseStrides#1 * 2 * (baseSizes#1 / 60) * 3
+//          = baseStrides#1 * (baseSizes#1 / 60) * 6
+//          and since we know that baseSizes#1 is a multiple of 60:
+//          = baseStrides#1 * (baseSizes#1 / 10)
+// stride 5 = baseStrides#1 * size 6 * size 7
+//          = baseStrides#1 * (baseSizes#1 / 60) * 3
+//          = baseStrides#1 * (baseSizes#1 / 20)
+// stride 6 = baseStrides#1 * size 7
+//          = baseStrides#1 * 3
+// stride 7 = baseStrides#1
+//
+// Base and offset are unchanged.
+//
+//   CHECK-DAG: #[[$DIM0_SIZE_MAP:.*]] = affine_map<()[s0] -> (s0 floordiv 504)>
+//   CHECK-DAG: #[[$DIM6_SIZE_MAP:.*]] = affine_map<()[s0] -> (s0 floordiv 60)>
+//
+//   CHECK-DAG: #[[$DIM0_STRIDE_MAP:.*]] = affine_map<()[s0] -> (s0 * 504)>
+//   CHECK-DAG: #[[$DIM1_STRIDE_MAP:.*]] = affine_map<()[s0] -> (s0 * 72)>
+//   CHECK-DAG: #[[$DIM2_STRIDE_MAP:.*]] = affine_map<()[s0] -> (s0 * 9)>
+//   CHECK-DAG: #[[$DIM4_STRIDE_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 floordiv 10) * s1)>
+//   CHECK-DAG: #[[$DIM5_STRIDE_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 floordiv 20) * s1)>
+//   CHECK-DAG: #[[$DIM6_STRIDE_MAP:.*]] = affine_map<()[s0] -> (s0 * 3)>
+// CHECK-LABEL: func @extract_strided_metadata_of_expand_shape_all_dynamic
+//  CHECK-SAME: (%[[ARG:.*]]: memref<?x?xf32,
+//
+//   CHECK-DAG: %[[C10:.*]] = arith.constant 10 : index
+//   CHECK-DAG: %[[C9:.*]] = arith.constant 9 : index
+//   CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
+//   CHECK-DAG: %[[C7:.*]] = arith.constant 7 : index
+//   CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
+//   CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+//
+//   CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]] : memref<?x?xf32, strided<[?, ?], offset: ?>> -> memref<f32>, index, index, index, index, index
+//
+//   CHECK-DAG: %[[DYN_SIZE0:.*]] = affine.apply #[[$DIM0_SIZE_MAP]]()[%[[SIZES]]#0]
+//   CHECK-DAG: %[[DYN_SIZE6:.*]] = affine.apply #[[$DIM6_SIZE_MAP]]()[%[[SIZES]]#1]
+//   CHECK-DAG: %[[DYN_STRIDE0:.*]] = affine.apply #[[$DIM0_STRIDE_MAP]]()[%[[STRIDES]]#0]
+//   CHECK-DAG: %[[DYN_STRIDE1:.*]] = affine.apply #[[$DIM1_STRIDE_MAP]]()[%[[STRIDES]]#0]
+//   CHECK-DAG: %[[DYN_STRIDE2:.*]] = affine.apply #[[$DIM2_STRIDE_MAP]]()[%[[STRIDES]]#0]
+//   CHECK-DAG: %[[DYN_STRIDE4:.*]] = affine.apply #[[$DIM4_STRIDE_MAP]]()[%[[SIZES]]#1, %[[STRIDES]]#1]
+//   CHECK-DAG: %[[DYN_STRIDE5:.*]] = affine.apply #[[$DIM5_STRIDE_MAP]]()[%[[SIZES]]#1, %[[STRIDES]]#1]
+//   CHECK-DAG: %[[DYN_STRIDE6:.*]] = affine.apply #[[$DIM6_STRIDE_MAP]]()[%[[STRIDES]]#1]
+//
+//   CHECK: return %[[BASE]], %[[OFFSET]], %[[DYN_SIZE0]], %[[C7]], %[[C8]], %[[C9]], %[[C10]], %[[C2]], %[[DYN_SIZE6]], %[[C3]], %[[DYN_STRIDE0]], %[[DYN_STRIDE1]], %[[DYN_STRIDE2]], %[[STRIDES]]#0, %[[DYN_STRIDE4]], %[[DYN_STRIDE5]], %[[DYN_STRIDE6]], %[[STRIDES]]#1 : memref<f32>, index, index, index, index, index, index, index, index, index, index, index, index, index, index, index, index, index
+func.func @extract_strided_metadata_of_expand_shape_all_dynamic(
+    %base: memref<?x?xf32, strided<[?,?], offset:?>>,
+    %offset0: index, %offset1: index, %offset2: index,
+    %size0: index, %size1: index, %size2: index,
+    %stride0: index, %stride1: index, %stride2: index)
+    -> (memref<f32>, index,
+       index, index, index, index, index, index, index, index,
+       index, index, index, index, index, index, index, index) {
+
+  %expand_shape = memref.expand_shape %base[[0, 1, 2, 3],[4, 5, 6, 7]] :
+    memref<?x?xf32, strided<[?,?], offset: ?>> into
+      memref<?x7x8x9x10x2x?x3xf32, strided<[?, ?, ?, ?, ?, ?, ?, ?], offset: ?>>
+
+  %base_buffer, %offset, %sizes:8, %strides:8 = memref.extract_strided_metadata %expand_shape :
+    memref<?x7x8x9x10x2x?x3xf32, strided<[?, ?, ?, ?, ?, ?, ?, ?], offset: ?>>
+    -> memref<f32>, index,
+       index, index, index, index, index, index, index, index,
+       index, index, index, index, index, index, index, index
+
+  return %base_buffer, %offset,
+    %sizes#0, %sizes#1, %sizes#2, %sizes#3, %sizes#4, %sizes#5, %sizes#6, %sizes#7,
+    %strides#0, %strides#1, %strides#2, %strides#3, %strides#4, %strides#5, %strides#6, %strides#7 :
+      memref<f32>, index,
+      index, index, index, index, index, index, index, index,
+      index, index, index, index, index, index, index, index
+}
+
+
+// -----
+
+// Check that we properly handle extract_strided_metadata of expand_shape for
+// 0-D input.
+// The 0-D case is pretty boring:
+// All expanded sizes are 1, likewise for the strides, and we keep the
+// original base and offset.
+// We still have a test for it because the reassociation map of such an
+// expand_shape is empty, which makes the handling of this shape a corner
+// case.
+// CHECK-LABEL: func @extract_strided_metadata_of_expand_shape_all_static_0_rank
+//  CHECK-SAME: (%[[ARG:.*]]: memref<i16, strided<[], offset: ?>>)
+//
+//   CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// The base buffer and the dynamic offset are forwarded from the source.
+//   CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]] = memref.extract_strided_metadata %[[ARG]] : memref<i16, strided<[], offset: ?>> -> memref<i16>, index
+//
+//   CHECK: return %[[BASE]], %[[OFFSET]], %[[C1]], %[[C1]], %[[C1]], %[[C1]], %[[C1]], %[[C1]], %[[C1]], %[[C1]], %[[C1]], %[[C1]] : memref<i16>, index, index, index, index, index, index, index, index, index, index, index
+func.func @extract_strided_metadata_of_expand_shape_all_static_0_rank(
+    %arg : memref<i16, strided<[], offset: ?>>)
+    -> (memref<i16>, index,
+       index, index, index, index, index,
+       index, index, index, index, index) {
+
+  %expand_shape = memref.expand_shape %arg[] :
+    memref<i16, strided<[], offset: ?>> into memref<1x1x1x1x1xi16, strided<[1,1,1,1,1], offset: ?>>
+
+  %base, %offset, %sizes:5, %strides:5 = memref.extract_strided_metadata %expand_shape :
+    memref<1x1x1x1x1xi16, strided<[1,1,1,1,1], offset: ?>>
+    -> memref<i16>, index,
+       index, index, index, index, index,
+       index, index, index, index, index
+
+  return %base, %offset,
+    %sizes#0, %sizes#1, %sizes#2, %sizes#3, %sizes#4,
+    %strides#0, %strides#1, %strides#2, %strides#3, %strides#4 :
+      memref<i16>, index,
+      index, index, index, index, index,
+      index, index, index, index, index
+}


        


More information about the Mlir-commits mailing list