[Mlir-commits] [mlir] fc8f465 - [mlir][MemRef] Allow transposed layouts in ExpandShapeOp.

Nicolas Vasilache llvmlistbot at llvm.org
Wed Apr 6 01:19:34 PDT 2022


Author: Nicolas Vasilache
Date: 2022-04-06T04:19:30-04:00
New Revision: fc8f465a0008826cb7431eb5684861477998662c

URL: https://github.com/llvm/llvm-project/commit/fc8f465a0008826cb7431eb5684861477998662c
DIFF: https://github.com/llvm/llvm-project/commit/fc8f465a0008826cb7431eb5684861477998662c.diff

LOG: [mlir][MemRef] Allow transposed layouts in ExpandShapeOp.

https://reviews.llvm.org/D122641 introduced fixes to the ExpandShapeOp verifier,
but it also introduced an artificial layout limitation that prevents transposed layouts from being considered.

This revision fixes the omission and reimplements the logic using saturated arithmetic, which is
more idiomatic and avoids leaking internal implementation details.

Test cases are added for transposed layouts.

Reviewed By: springerm

Differential Revision: https://reviews.llvm.org/D122845
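As background for the diff below, here is a minimal standalone C++ sketch of
the saturated-arithmetic idea (an illustration only, not the in-tree code).
It uses a hypothetical kDynamic sentinel of -1 in place of the actual
ShapedType::kDynamicSize / ShapedType::kDynamicStrideOrOffset constants:

    #include <cassert>
    #include <cstdint>

    // Hypothetical stand-in for ShapedType's dynamic sentinels.
    constexpr int64_t kDynamic = -1;

    struct Saturated {
      bool saturated;
      int64_t v;
      static Saturated wrap(int64_t v) {
        return v == kDynamic ? Saturated{true, 0} : Saturated{false, v};
      }
      int64_t unwrap() const { return saturated ? kDynamic : v; }
      // Any dynamic operand saturates the result to "dynamic".
      Saturated operator*(Saturated other) const {
        if (saturated || other.saturated)
          return Saturated{true, 0};
        return Saturated{false, v * other.v};
      }
    };

    int main() {
      // dynamic * 10 stays dynamic; 5 * 10 is plain 50.
      assert((Saturated::wrap(kDynamic) * Saturated::wrap(10)).unwrap() ==
             kDynamic);
      assert((Saturated::wrap(5) * Saturated::wrap(10)).unwrap() == 50);
      return 0;
    }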

Added: 
    

Modified: 
    mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
    mlir/test/Dialect/MemRef/invalid.mlir
    mlir/test/Dialect/MemRef/ops.mlir
    mlir/test/Dialect/Tensor/bufferize.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 5a8bd2b8dd551..094e6611cd101 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -26,6 +26,49 @@
 using namespace mlir;
 using namespace mlir::memref;
 
+namespace {
+/// Idiomatic saturated operations on offsets, sizes and strides.
+namespace saturated_arith {
+struct Wrapper {
+  static Wrapper stride(int64_t v) {
+    return (ShapedType::isDynamicStrideOrOffset(v)) ? Wrapper{true, 0}
+                                                    : Wrapper{false, v};
+  }
+  static Wrapper offset(int64_t v) {
+    return (ShapedType::isDynamicStrideOrOffset(v)) ? Wrapper{true, 0}
+                                                    : Wrapper{false, v};
+  }
+  static Wrapper size(int64_t v) {
+    return (ShapedType::isDynamic(v)) ? Wrapper{true, 0} : Wrapper{false, v};
+  }
+  int64_t asOffset() {
+    return saturated ? ShapedType::kDynamicStrideOrOffset : v;
+  }
+  int64_t asSize() { return saturated ? ShapedType::kDynamicSize : v; }
+  int64_t asStride() {
+    return saturated ? ShapedType::kDynamicStrideOrOffset : v;
+  }
+  bool operator==(Wrapper other) {
+    return (saturated && other.saturated) ||
+           (!saturated && !other.saturated && v == other.v);
+  }
+  bool operator!=(Wrapper other) { return !(*this == other); }
+  Wrapper operator+(Wrapper other) {
+    if (saturated || other.saturated)
+      return Wrapper{true, 0};
+    return Wrapper{false, other.v + v};
+  }
+  Wrapper operator*(Wrapper other) {
+    if (saturated || other.saturated)
+      return Wrapper{true, 0};
+    return Wrapper{false, other.v * v};
+  }
+  bool saturated;
+  int64_t v;
+};
+} // namespace saturated_arith
+} // namespace
+
 /// Materialize a single constant operation from a given attribute value with
 /// the desired resultant type.
 Operation *MemRefDialect::materializeConstant(OpBuilder &builder,
@@ -1558,24 +1601,6 @@ OpFoldResult ReinterpretCastOp::fold(ArrayRef<Attribute> /*operands*/) {
 // Reassociative reshape ops
 //===----------------------------------------------------------------------===//
 
-/// Helper function that computes a stride based on the size/stride of the
-/// previous dimension.
-///
-/// E.g., memref<20x10x5xf32, offset: 0, strides: [50, 5, 1]>
-///                                                ^^
-///                                        compute this one
-///   prevStride = 5, prevDimSize = 10
-///   nextStride = 5 * 10 = 50
-static int64_t computeNextStride(int64_t prevStride, int64_t prevDimSize) {
-  if (ShapedType::isDynamicStrideOrOffset(prevStride))
-    return ShapedType::kDynamicStrideOrOffset;
-
-  if (ShapedType::isDynamic(prevDimSize))
-    return ShapedType::kDynamicStrideOrOffset;
-
-  return prevStride * prevDimSize;
-}
-
 /// Helper function for verifying the shape of ExpandShapeOp and
 /// CollapseShapeOp result and operand. Layout maps are verified separately.
 ///
@@ -1677,57 +1702,41 @@ SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
 static FailureOr<AffineMap>
 computeExpandedLayoutMap(MemRefType srcType, ArrayRef<int64_t> resultShape,
                          ArrayRef<ReassociationIndices> reassociation) {
-  SmallVector<int64_t> srcStrides, resultStrides(resultShape.size(), 0);
   int64_t srcOffset;
+  SmallVector<int64_t> srcStrides;
   if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset)))
     return failure();
   assert(srcStrides.size() == reassociation.size() && "invalid reassociation");
 
-  // Ensure that inner strides are the fastest-varying ones. Other source layout
-  // maps are currently not supported.
-  int64_t lastStride = 0;
-  for (int64_t s : llvm::reverse(srcStrides)) {
-    if (!ShapedType::isDynamicStrideOrOffset(s)) {
-      if (s < lastStride)
-        return failure();
-      lastStride = s;
-    }
-  }
-
-  // Iterate over all reassociation groups from the back. Example:
-  // strides       = [1000, ?, 2]
-  // source shape  = [20,  10, 5]
-  // result shape  = [ 2, 10,   2, 5,   5]
-  // reassociation = [[0,  1], [2, 3], [4]]
-  for (const auto &it : llvm::reverse(llvm::zip(reassociation, srcStrides))) {
-    ReassociationIndices indices = std::get<0>(it);
-    int64_t srcGroupStride = std::get<1>(it);
-
-    // The first result dimension (least significant one) in each reassociation
-    // group has the same stride as the corresponding source dimension. E.g.:
-    // reassociation = [[0, 1], [2, 3], [4]]
-    //                      |       |    |
-    //                      v       v    v
-    //                    1000      ?    2
-    resultStrides[indices.pop_back_val()] = srcGroupStride;
-
-    // Compute the strides for the remaining dims in the reassociation group.
-    for (int64_t resultDim : llvm::reverse(indices)) {
-      // E.g.:
-      // reassociation = [[0, 1], [2, 3], [4]]
-      //                   |
-      //                   v
-      //               1000 * 10 = 10000
-      //
-      // If the previous stride or the previous dimension was dynamic, then this
-      // stride will also be dynamic.
-      resultStrides[resultDim] = computeNextStride(resultStrides[resultDim + 1],
-                                                   resultShape[resultDim + 1]);
+  // 1-1 mapping between srcStrides and reassociation packs.
+  // Each srcStride starts with the given value and gets expanded according to
+  // the proper entries in resultShape.
+  // Example:
+  //   srcStrides     =                   [10000,  1 ,    100   ],
+  //   reassociations =                   [  [0], [1], [2, 3, 4]],
+  //   resultSizes    = [2, 5, 4, 3, 2] = [  [2], [5], [4, 3, 2]]
+  //     -> For the purpose of stride calculation, the useful sizes are:
+  //                    [x, x, x, 3, 2] = [  [x], [x], [x, 3, 2]].
+  //   resultStrides = [10000, 1, 600, 200, 100]
+  // Note that a stride does not get expanded along the first entry of each
+  // shape pack.
+  SmallVector<int64_t> reverseResultStrides;
+  reverseResultStrides.reserve(resultShape.size());
+  unsigned shapeIndex = resultShape.size() - 1;
+  for (auto it : llvm::reverse(llvm::zip(reassociation, srcStrides))) {
+    ReassociationIndices reassoc = std::get<0>(it);
+    int64_t currentStrideToExpand = std::get<1>(it);
+    for (unsigned idx = 0, e = reassoc.size(); idx < e; ++idx) {
+      using saturated_arith::Wrapper;
+      reverseResultStrides.push_back(currentStrideToExpand);
+      currentStrideToExpand = (Wrapper::stride(currentStrideToExpand) *
+                               Wrapper::size(resultShape[shapeIndex--]))
+                                  .asStride();
     }
   }
-
-  return makeStridedLinearLayoutMap(resultStrides, srcOffset,
-                                    srcType.getContext());
+  return makeStridedLinearLayoutMap(
+      llvm::to_vector<8>(llvm::reverse(reverseResultStrides)), srcOffset,
+      srcType.getContext());
 }
 
 static FailureOr<MemRefType>
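To make the new stride-expansion loop concrete, here is a standalone C++
model of it using plain integers and the all-static example from the comment
(a sketch; the in-tree version additionally saturates dynamic values):

    #include <cassert>
    #include <cstdint>
    #include <vector>

    int main() {
      // Example from the comment: one source stride per reassociation group.
      std::vector<int64_t> srcStrides = {10000, 1, 100};
      std::vector<std::vector<int64_t>> reassoc = {{0}, {1}, {2, 3, 4}};
      std::vector<int64_t> resultShape = {2, 5, 4, 3, 2};

      std::vector<int64_t> reverseResultStrides;
      int shapeIndex = static_cast<int>(resultShape.size()) - 1;
      // Walk reassociation groups and source strides back to front.
      for (int g = static_cast<int>(reassoc.size()) - 1; g >= 0; --g) {
        int64_t stride = srcStrides[g];
        for (size_t i = 0; i < reassoc[g].size(); ++i) {
          reverseResultStrides.push_back(stride);
          // The next dim in the group varies slower by this result size.
          stride *= resultShape[shapeIndex--];
        }
      }
      std::vector<int64_t> resultStrides(reverseResultStrides.rbegin(),
                                         reverseResultStrides.rend());
      assert((resultStrides ==
              std::vector<int64_t>{10000, 1, 600, 200, 100}));
      return 0;
    }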
@@ -1804,94 +1813,52 @@ void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
 /// not possible to check this by inspecting a MemRefType in the general case.
 /// But it is assumed. If this is not the case, the behavior is undefined.
 static FailureOr<AffineMap>
-computeCollapsedLayoutMap(MemRefType srcType, ArrayRef<int64_t> resultShape,
+computeCollapsedLayoutMap(MemRefType srcType,
                           ArrayRef<ReassociationIndices> reassociation) {
-  SmallVector<int64_t> srcStrides, resultStrides;
   int64_t srcOffset;
+  SmallVector<int64_t> srcStrides;
+  auto srcShape = srcType.getShape();
   if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset)))
     return failure();
-  assert(resultShape.size() == reassociation.size() && "invalid reassociation");
-
-  // Iterate over all reassociation groups from the back. Example:
-  // source shape   = [20, ?,   5, 10, 2]
-  // source strides = [ ?, ?, 800, 80, 4]
-  // reassociation  = [[0, 1], [2, 3], [4]]
-  // result shape   = [     ?,     50,   2]
-  //
-  // Note: The result shape is not needed in this computation. It is just used
-  // check that the size of the reassociation is correct.
-  for (ReassociationIndices group : llvm::reverse(reassociation)) {
-    // A result dim has the same stride as the first dimension (least
-    // significant one) in the corresponding reassociation group. E.g.:
-    // reassociation  = [[0, 1], [2, 3], [4]]
-    //                       |       |    |
-    //                       v       v    v
-    //                       ?      80    4
-    int64_t resultStride = srcStrides[group.pop_back_val()];
-
-    // The following is just a best-effort check for non-contiguous source
-    // strides within a reassociation group. E.g.:
-    // reassociation  = [[0, 1], [2, 3], [4]]
-    //                           ^^^^^^
-    // Iteratively compute the next stride within the reassociation group
-    // one-by-one. Start with the stride computed above. E.g.:
-    // reassociation  = [[0, 1], [2, 3], [4]]
-    //                               |
-    //                               v
-    //                  nextStride = 80
-    int64_t nextStride = resultStride;
-    for (int64_t nextDim : llvm::reverse(group)) {
-      // Next expected stride is previous stride multiplied by dim size, e.g.:
-      // reassociation  = [[0, 1], [2, 3], [4]]
-      //                            |
-      //                            v
-      //    nextStride = 80 * 10 = 800
-      nextStride =
-          computeNextStride(nextStride, srcType.getDimSize(nextDim + 1));
-
-      // Ensure that the source actually has this stride value. E.g.:
-      // source strides = [ ?, ?, 800, 80, 4]
-      //                           |
-      //                           v
-      //                  same stride, OK
-      // If strides are dynamic, we cannot verify anything statically.
-      if (!ShapedType::isDynamicStrideOrOffset(srcStrides[nextDim]) &&
-          !ShapedType::isDynamicStrideOrOffset(nextStride) &&
-          srcStrides[nextDim] != nextStride) {
-        // Attempting to collapse non-contiguous dimensions. This is forbidden.
-        // Note: This check does not handle cases where strides and dimension
-        // sizes are dynamic. Such dims could still turn out to be non-
-        // contiguous at runtime. This check is only a best effort to catch
-        // illegal collapses at verification time.
+
+  // The result strides are exactly the strides of the last entry of each
+  // reassociation.
+  SmallVector<int64_t> resultStrides;
+  resultStrides.reserve(reassociation.size());
+  for (ReassociationIndices reassoc : reassociation)
+    resultStrides.push_back(srcStrides[reassoc.back()]);
+
+  // Validate that each reassociation group is contiguous.
+  unsigned resultStrideIndex = resultStrides.size() - 1;
+  for (ReassociationIndices reassoc : llvm::reverse(reassociation)) {
+    auto trailingReassocs = ArrayRef<int64_t>(reassoc).drop_front();
+    using saturated_arith::Wrapper;
+    auto stride = Wrapper::stride(resultStrides[resultStrideIndex--]);
+    for (int64_t idx : llvm::reverse(trailingReassocs)) {
+      stride = stride * Wrapper::size(srcShape[idx]);
+      // Both are either static strides of the same value, or both are dynamic.
+      // The dynamic case is best effort for now: we can't check it
+      // statically. One exception is a source dim size of `1`, which can
+      // never produce a non-contiguity.
+      if (stride != Wrapper::stride(srcStrides[idx - 1]) && srcShape[idx] != 1)
         return failure();
-      }
     }
-
-    resultStrides.push_back(resultStride);
   }
-
-  return makeStridedLinearLayoutMap(
-      llvm::to_vector<8>(llvm::reverse(resultStrides)), srcOffset,
-      srcType.getContext());
+  return makeStridedLinearLayoutMap(resultStrides, srcOffset,
+                                    srcType.getContext());
 }
 
 static MemRefType
 computeCollapsedType(MemRefType srcType,
                      ArrayRef<ReassociationIndices> reassociation) {
   SmallVector<int64_t> resultShape;
+  resultShape.reserve(reassociation.size());
   for (const ReassociationIndices &group : reassociation) {
-    int64_t groupSize = 1;
-    for (int64_t srcDim : group) {
-      if (srcType.isDynamicDim(srcDim)) {
-        // Source dim is dynamic, so the collapsed dim is also dynamic.
-        groupSize = ShapedType::kDynamicSize;
-        break;
-      }
-
-      groupSize *= srcType.getDimSize(srcDim);
-    }
-
-    resultShape.push_back(groupSize);
+    using saturated_arith::Wrapper;
+    auto groupSize = Wrapper::size(1);
+    for (int64_t srcDim : group)
+      groupSize = groupSize * Wrapper::size(srcType.getDimSize(srcDim));
+    resultShape.push_back(groupSize.asSize());
   }
 
   if (srcType.getLayout().isIdentity()) {
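As a concrete instance of the contiguity validation above (a worked example,
not taken from the tests): collapsing the group [2, 3, 4] of a source with
shape [2, 5, 4, 3, 2] and strides [10000, 1, 600, 200, 100] starts from the
trailing stride 100; then 100 * 2 = 200 matches srcStrides[3], and
200 * 3 = 600 matches srcStrides[2], so the group is contiguous and the
collapse is accepted.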
@@ -1906,7 +1873,7 @@ computeCollapsedType(MemRefType srcType,
   // Note: Dimensions that are collapsed into a single dim are assumed to be
   // contiguous.
   FailureOr<AffineMap> computedLayout =
-      computeCollapsedLayoutMap(srcType, resultShape, reassociation);
+      computeCollapsedLayoutMap(srcType, reassociation);
   assert(succeeded(computedLayout) &&
          "invalid source layout map or collapsing non-contiguous dims");
   auto computedType =
@@ -1948,8 +1915,8 @@ LogicalResult CollapseShapeOp::verify() {
     // Source may not be fully contiguous. Compute the layout map.
     // Note: Dimensions that are collapsed into a single dim are assumed to be
     // contiguous.
-    FailureOr<AffineMap> computedLayout = computeCollapsedLayoutMap(
-        srcType, resultType.getShape(), getReassociationIndices());
+    FailureOr<AffineMap> computedLayout =
+        computeCollapsedLayoutMap(srcType, getReassociationIndices());
     if (failed(computedLayout))
       return emitOpError(
           "invalid source layout map or collapsing non-contiguous dims");
@@ -2066,29 +2033,6 @@ LogicalResult StoreOp::fold(ArrayRef<Attribute> cstOperands,
 // SubViewOp
 //===----------------------------------------------------------------------===//
 
-namespace {
-/// Helpers to write more idiomatic operations.
-namespace saturated_arith {
-struct Wrapper {
-  explicit Wrapper(int64_t v) : v(v) {}
-  operator int64_t() { return v; }
-  int64_t v;
-};
-Wrapper operator+(Wrapper a, int64_t b) {
-  if (ShapedType::isDynamicStrideOrOffset(a) ||
-      ShapedType::isDynamicStrideOrOffset(b))
-    return Wrapper(ShapedType::kDynamicStrideOrOffset);
-  return Wrapper(a.v + b);
-}
-Wrapper operator*(Wrapper a, int64_t b) {
-  if (ShapedType::isDynamicStrideOrOffset(a) ||
-      ShapedType::isDynamicStrideOrOffset(b))
-    return Wrapper(ShapedType::kDynamicStrideOrOffset);
-  return Wrapper(a.v * b);
-}
-} // namespace saturated_arith
-} // namespace
-
 /// A subview result type can be fully inferred from the source type and the
 /// static representation of offsets, sizes and strides. Special sentinels
 /// encode the dynamic case.
@@ -2114,8 +2058,11 @@ Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
   int64_t targetOffset = sourceOffset;
   for (auto it : llvm::zip(staticOffsets, sourceStrides)) {
     auto staticOffset = std::get<0>(it), targetStride = std::get<1>(it);
-    using namespace saturated_arith;
-    targetOffset = Wrapper(targetOffset) + Wrapper(staticOffset) * targetStride;
+    using saturated_arith::Wrapper;
+    targetOffset =
+        (Wrapper::offset(targetOffset) +
+         Wrapper::offset(staticOffset) * Wrapper::stride(targetStride))
+            .asOffset();
   }
 
   // Compute target stride whose value is:
@@ -2124,8 +2071,10 @@ Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
   targetStrides.reserve(staticOffsets.size());
   for (auto it : llvm::zip(sourceStrides, staticStrides)) {
     auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it);
-    using namespace saturated_arith;
-    targetStrides.push_back(Wrapper(sourceStride) * staticStride);
+    using saturated_arith::Wrapper;
+    targetStrides.push_back(
+        (Wrapper::stride(sourceStride) * Wrapper::stride(staticStride))
+            .asStride());
   }
 
   // The type is now known.
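For an illustrative all-static example of the two computations above: a
subview with static offsets [2, 3] and static strides [1, 2] of a source
with offset 0 and strides [10, 1] yields
targetOffset = 0 + 2 * 10 + 3 * 1 = 23 and
targetStrides = [10 * 1, 1 * 2] = [10, 2]; if any operand is dynamic, the
corresponding result saturates to the dynamic sentinel.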
@@ -2305,8 +2254,8 @@ void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
 /// For ViewLikeOpInterface.
 Value SubViewOp::getViewSource() { return source(); }
 
-/// Return true if t1 and t2 have equal offsets (both dynamic or of same static
-/// value).
+/// Return true if t1 and t2 have equal offsets (both dynamic or of same
+/// static value).
 static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2) {
   AffineExpr t1Offset, t2Offset;
   SmallVector<AffineExpr> t1Strides, t2Strides;
@@ -2431,12 +2380,12 @@ SmallVector<Range, 8> mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
   return res;
 }
 
-/// Compute the canonical result type of a SubViewOp. Call `inferResultType` to
-/// deduce the result type for the given `sourceType`. Additionally, reduce the
-/// rank of the inferred result type if `currentResultType` is lower rank than
-/// `currentSourceType`. Use this signature if `sourceType` is updated together
-/// with the result type. In this case, it is important to compute the dropped
-/// dimensions using `currentSourceType` whose strides align with
+/// Compute the canonical result type of a SubViewOp. Call `inferResultType`
+/// to deduce the result type for the given `sourceType`. Additionally, reduce
+/// the rank of the inferred result type if `currentResultType` is lower rank
+/// than `currentSourceType`. Use this signature if `sourceType` is updated
+/// together with the result type. In this case, it is important to compute
+/// the dropped dimensions using `currentSourceType` whose strides align with
 /// `currentResultType`.
 static MemRefType getCanonicalSubViewResultType(
     MemRefType currentResultType, MemRefType currentSourceType,
@@ -2464,9 +2413,9 @@ static MemRefType getCanonicalSubViewResultType(
                          nonRankReducedType.getMemorySpace());
 }
 
-/// Compute the canonical result type of a SubViewOp. Call `inferResultType` to
-/// deduce the result type. Additionally, reduce the rank of the inferred result
-/// type if `currentResultType` is lower rank than `sourceType`.
+/// Compute the canonical result type of a SubViewOp. Call `inferResultType`
+/// to deduce the result type. Additionally, reduce the rank of the inferred
+/// result type if `currentResultType` is lower rank than `sourceType`.
 static MemRefType getCanonicalSubViewResultType(
     MemRefType currentResultType, MemRefType sourceType,
     ArrayRef<OpFoldResult> mixedOffsets, ArrayRef<OpFoldResult> mixedSizes,
@@ -2478,8 +2427,8 @@ static MemRefType getCanonicalSubViewResultType(
 
 /// Helper method to check if a `subview` operation is trivially a no-op. This
 /// is the case if the all offsets are zero, all strides are 1, and the source
-/// shape is same as the size of the subview. In such cases, the subview can be
-/// folded into its source.
+/// shape is same as the size of the subview. In such cases, the subview can
+/// be folded into its source.
 static bool isTrivialSubViewOp(SubViewOp subViewOp) {
   if (subViewOp.getSourceType().getRank() != subViewOp.getType().getRank())
     return false;
@@ -2536,7 +2485,8 @@ class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
 
   LogicalResult matchAndRewrite(SubViewOp subViewOp,
                                 PatternRewriter &rewriter) const override {
-    // Any constant operand, just return to let SubViewOpConstantFolder kick in.
+    // If any operand is constant, just return and let SubViewOpConstantFolder
+    // kick in.
     if (llvm::any_of(subViewOp.getOperands(), [](Value operand) {
           return matchPattern(operand, matchConstantIndex());
         }))
@@ -2549,10 +2499,10 @@ class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
     if (!CastOp::canFoldIntoConsumerOp(castOp))
       return failure();
 
-    // Compute the SubViewOp result type after folding the MemRefCastOp. Use the
-    // MemRefCastOp source operand type to infer the result type and the current
-    // SubViewOp source operand type to compute the dropped dimensions if the
-    // operation is rank-reducing.
+    // Compute the SubViewOp result type after folding the MemRefCastOp. Use
+    // the MemRefCastOp source operand type to infer the result type and the
+    // current SubViewOp source operand type to compute the dropped dimensions
+    // if the operation is rank-reducing.
     auto resultType = getCanonicalSubViewResultType(
         subViewOp.getType(), subViewOp.getSourceType(),
         castOp.source().getType().cast<MemRefType>(),
@@ -2571,8 +2521,8 @@ class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
   }
 };
 
-/// Canonicalize subview ops that are no-ops. When the source shape is not same
-/// as a result shape due to use of `affine_map`.
+/// Canonicalize subview ops that are no-ops, even when the source shape is
+/// not the same as the result shape due to the use of `affine_map`.
 class TrivialSubViewOpFolder final : public OpRewritePattern<SubViewOp> {
 public:
   using OpRewritePattern<SubViewOp>::OpRewritePattern;

diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index 21ebd57564a39..03ab9eb5db9ad 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -517,18 +517,6 @@ func @expand_shape_illegal_mixed_memref_2(%arg0 : memref<?x?xf32>)
 
 // -----
 
-func @expand_shape_unsupported_src_layout(
-    %arg0 : memref<20x2x10x5xf32, offset: 0, strides: [100, 10, 50, 1]>)
-    -> memref<20x2x2x5x5xf32, offset : 0, strides : [100, 10, 250, 50, 1]> {
-  // expected-error @+1 {{invalid source layout map}}
-  %0 = memref.expand_shape %arg0 [[0], [1], [2, 3], [4]] :
-      memref<20x2x10x5xf32, offset: 0, strides: [100, 10, 50, 1]>
-      into memref<20x2x2x5x5xf32, offset : 0, strides : [100, 10, 250, 50, 1]>
-  return %0 : memref<20x2x2x5x5xf32, offset : 0, strides : [100, 10, 250, 50, 1]>
-}
-
-// -----
-
 func @expand_shape_invalid_static_dim_size(%arg0 : memref<?x21xf32>)
     -> memref<?x4x5xf32> {
   // expected-error @+1 {{collapsed dim size (21) must equal reassociation group size (20)}}

diff --git a/mlir/test/Dialect/MemRef/ops.mlir b/mlir/test/Dialect/MemRef/ops.mlir
index 8f49b6e14d7d2..bd2a6fc7489c9 100644
--- a/mlir/test/Dialect/MemRef/ops.mlir
+++ b/mlir/test/Dialect/MemRef/ops.mlir
@@ -281,6 +281,29 @@ func @collapse_shape_to_dynamic
 
 // -----
 
+// CHECK-LABEL: func @expand_collapse_shape_transposed_layout
+func @expand_collapse_shape_transposed_layout(
+    %m0: memref<?x?xf32, offset : 0, strides : [1, 10]>,
+    %m1: memref<4x5x6xf32, offset : 0, strides : [1, ?, 1000]>) {
+
+  %r0 = memref.expand_shape %m0 [[0], [1, 2]] :
+    memref<?x?xf32, offset : 0, strides : [1, 10]> into
+    memref<?x?x5xf32, offset : 0, strides : [1, 50, 10]>
+  %rr0 = memref.collapse_shape %r0 [[0], [1, 2]] :
+    memref<?x?x5xf32, offset : 0, strides : [1, 50, 10]> into
+    memref<?x?xf32, offset : 0, strides : [1, 10]>
+
+  %r1 = memref.expand_shape %m1 [[0, 1], [2], [3, 4]] :
+    memref<4x5x6xf32, offset : 0, strides : [1, ?, 1000]> into 
+    memref<2x2x5x2x3xf32, offset : 0, strides : [2, 1, ?, 3000, 1000]>
+  %rr1 = memref.collapse_shape %r1 [[0, 1], [2], [3, 4]] :
+    memref<2x2x5x2x3xf32, offset : 0, strides : [2, 1, ?, 3000, 1000]> into
+    memref<4x5x6xf32, offset : 0, strides : [1, ?, 1000]>
+  return
+}
+
+// -----
+
 func @rank(%t : memref<4x4x?xf32>) {
   // CHECK: %{{.*}} = memref.rank %{{.*}} : memref<4x4x?xf32>
   %0 = "memref.rank"(%t) : (memref<4x4x?xf32>) -> index

diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir
index 2178d9d3fa4fc..b29ec0201f09f 100644
--- a/mlir/test/Dialect/Tensor/bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/bufferize.mlir
@@ -5,6 +5,7 @@
 // CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 140 + d1 * 20 + d2 * 5 + d3 + s0)>
 // CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0) -> (d0 + 1)>
 // CHECK-DAG: #[[$MAP4:.*]] = affine_map<() -> (1)>
+// CHECK-DAG: #[[$MAP5:.*]] = affine_map<(d0, d1) -> (d0 * 2 + d1)>
 
 // CHECK-LABEL:   func @dim(
 // CHECK-SAME:              %[[TENSOR:.*]]: tensor<f32>,
@@ -337,6 +338,17 @@ func @tensor.expand_shape_of_slice(
   return %1 : tensor<?x7x2x5xf32>
 }
 
+// CHECK-LABEL: func @tensor.expand_shape_of_slice2(
+//  CHECK-SAME:     %[[t1:.*]]: tensor<1x2xf32>
+func @tensor.expand_shape_of_slice2(%t1: tensor<1x2xf32>) -> tensor<1xf32> {
+  // CHECK: memref.subview {{.*}} : memref<1x2xf32> to memref<1x1xf32, #[[$MAP5]]>
+  %0 = tensor.extract_slice %t1[0, 0][1, 1][1, 1] : tensor<1x2xf32> to tensor<1x1xf32>
+  // CHECK: memref.collapse_shape %{{.*}} [
+  // CHECK-SAME: [0, 1]] : memref<1x1xf32, #[[$MAP5]]> into memref<1xf32>
+  %1 = tensor.collapse_shape %0 [[0, 1]] : tensor<1x1xf32> into tensor<1xf32>
+  return %1 : tensor<1xf32>
+}
+
 // CHECK-LABEL: func @tensor.collapse_shape(
 //  CHECK-SAME:     %[[t1:.*]]: tensor<2x?x?xf32>
 func @tensor.collapse_shape(%t1: tensor<2x?x?xf32>) -> tensor<?x?xf32> {
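The new test above exercises the size-1 exception in the contiguity check:
the subview result has strides [2, 1] (hence #MAP5 = (d0, d1) -> (d0 * 2 +
d1)), so collapsing [0, 1] computes 1 * 1 = 1, which differs from the source
stride 2; but because the corresponding source dim has size 1, the mismatch
cannot produce an actual non-contiguity, and the collapse is accepted.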

