[Mlir-commits] [mlir] 2bd7ee4 - [mlir][memref] Fix ExpandShapeOp verifier

Matthias Springer llvmlistbot at llvm.org
Thu Mar 31 01:08:01 PDT 2022


Author: Matthias Springer
Date: 2022-03-31T17:05:52+09:00
New Revision: 2bd7ee45666f1093bafe1a37e1c9ade8aa78ddd2

URL: https://github.com/llvm/llvm-project/commit/2bd7ee45666f1093bafe1a37e1c9ade8aa78ddd2
DIFF: https://github.com/llvm/llvm-project/commit/2bd7ee45666f1093bafe1a37e1c9ade8aa78ddd2.diff

LOG: [mlir][memref] Fix ExpandShapeOp verifier

* Complete rewrite of the verifier.
* CollapseShapeOp verifier will be updated in a subsequent commit.
* Update and expand op documentation.
* Add a new builder that infers the result type based on the source type, result shape, and reassociation indices. In essence, only the result layout map is inferred.
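A minimal usage sketch of the new builder (hypothetical caller code, not part of this patch; assumes an in-scope `OpBuilder b`, a `Location loc`, and a `Value src` of type `memref<?x?xf32>`):

```c++
// Expand memref<?x?xf32> into memref<?x5x?xf32>. The result type, including
// its layout map, is inferred from the source type, the result shape, and the
// reassociation indices.
SmallVector<ReassociationIndices> reassociation = {{0, 1}, {2}};
Value expanded = b.create<memref::ExpandShapeOp>(
    loc,
    /*resultShape=*/ArrayRef<int64_t>{ShapedType::kDynamicSize, 5,
                                      ShapedType::kDynamicSize},
    src, reassociation);
```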

Differential Revision: https://reviews.llvm.org/D122641

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
    mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
    mlir/test/Dialect/MemRef/invalid.mlir
    mlir/test/Dialect/MemRef/ops.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index c7519f1125f12..ffc6ed2263bc3 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1196,7 +1196,9 @@ class MemRef_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
   
   code commonExtraClassDeclaration = [{
     SmallVector<AffineMap, 4> getReassociationMaps();
+
     SmallVector<ReassociationExprs, 4> getReassociationExprs();
+
     SmallVector<ReassociationIndices, 4> getReassociationIndices() {
       SmallVector<ReassociationIndices, 4> reassociationIndices;
       for (auto attr : reassociation())
@@ -1206,8 +1208,11 @@ class MemRef_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
             })));
       return reassociationIndices;
     };
+
     MemRefType getSrcType() { return src().getType().cast<MemRefType>(); }
+
     MemRefType getResultType() { return result().getType().cast<MemRefType>(); }
+
     Value getViewSource() { return src(); }
   }];
 
@@ -1224,36 +1229,45 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape"> {
   let summary = "operation to produce a memref with a higher rank.";
   let description = [{
     The `memref.expand_shape` op produces a new view with a higher rank whose
-    sizes are a reassociation of the original `view`. Depending on whether or
-    not the reassociated MemRefType is contiguous, the resulting memref may
-    require explicit alloc and copies.
-
-    A reassociation is defined as a continuous grouping of dimensions and is
-    represented with an array of I64ArrayAttr attribute.
+    sizes are a reassociation of the original `view`. The operation is limited
+    to reassociations in which a dimension is expanded into one or multiple
+    contiguous dimensions. Such reassociations never require additional allocs
+    or copies.
 
-    For now, it is assumed that either:
-      1. a reassociation produces and consumes contiguous MemRefType or,
-      2. the reshape op will be folded into its consumers (by changing the shape
-         of the computations).
-    All other cases are undefined behavior and a reshape op may not lower to
-    LLVM if it cannot be proven statically that it does not require alloc+copy.
-
-    The operand memref type when dimensions can be zero-ranked if the result
-    memref type is statically shaped with all dimensions being unit extent. In
-    such case the reassociation map is empty.
-
-    The verification rule is that the reassociation maps are applied to the
-    result memref with the larger rank to obtain the operand memref with the
-    smaller rank.
+    A reassociation is defined as a grouping of dimensions and is represented
+    with an array of I64ArrayAttr attributes.
 
     Example:
 
     ```mlir
-    // Dimension expansion i -> (i', j') and (k) -> (k')
-    %1 = memref.expand_shape %0 [[0, 1], [2]] :
-      memref<?x?xf32, stride_spec> into memref<?x?x?xf32, stride_spec_2>
+    %r = memref.expand_shape %0 [[0, 1], [2]]
+        : memref<?x?xf32> into memref<?x5x?xf32>
     ```
+
+    At most one dimension of a reassociation group (e.g., [0, 1] above) may be
+    dynamic in the result type. Otherwise, the op would be ambiguous, as it
+    would not be clear how the source dimension is to be expanded.
+
+    If an op can be statically proven to be invalid (e.g., an expansion from
+    `memref<10xf32>` to `memref<2x6xf32>`), it is rejected by the verifier. If
+    it cannot statically be proven invalid (e.g., the full example above; it is
+    unclear whether the first source dimension is divisible by 5), the op is
+    accepted by the verifier. However, if the op is in fact invalid at runtime,
+    the behavior is undefined.
+
+    The source memref can be zero-ranked. In that case, the reassociation
+    indices must be empty and the result shape may only consist of unit
+    dimensions.
+
+    For simplicity, this op may not be used to cast dynamicity of dimension
+    sizes and/or strides. I.e., a source dimension is dynamic if and only if
+    there is a dynamic result dimension in the corresponding reassociation
+    group. The same applies to strides.
+
+    Note: This op currently assumes that the inner strides of the
+    source/result layout map are the faster-varying ones.
   }];
+
   let builders = [
     // Builders using ReassociationIndices.
     OpBuilder<(ins "Type":$resultType, "Value":$src,
@@ -1264,6 +1278,8 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape"> {
       $_state.addAttribute("reassociation",
                           getReassociationIndicesAttribute($_builder, reassociation));
     }]>,
+
+    // Builder using ReassociationExprs.
     OpBuilder<(ins "Type":$resultType, "Value":$src,
       "ArrayRef<ReassociationExprs>":$reassociation,
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
@@ -1271,8 +1287,14 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape"> {
       auto reassociationMaps =
           convertReassociationMapsToIndices($_builder, reassociation);
       build($_builder, $_state, resultType, src, reassociationMaps, attrs);
-    }]>
+    }]>,
+
+    // Builder that infers the result layout map. The result shape must be
+    // specified. Otherwise, the op may be ambiguous.
+    OpBuilder<(ins "ArrayRef<int64_t>":$resultShape, "Value":$src,
+               "ArrayRef<ReassociationIndices>":$reassociation)>
   ];
+
   let extraClassDeclaration = commonExtraClassDeclaration;
   let hasVerifier = 1;
 }

diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 67f166da5ec1a..6b81895a4772a 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1558,9 +1558,106 @@ OpFoldResult ReinterpretCastOp::fold(ArrayRef<Attribute> /*operands*/) {
 // Reassociative reshape ops
 //===----------------------------------------------------------------------===//
 
+/// Helper function that computes a stride based on the size/stride of the
+/// previous dimension.
+///
+/// E.g., memref<20x10x5xf32, offset: 0, strides: [50, 5, 1]>
+///                                                ^^
+///                                        compute this one
+///   prevStride = 5, prevDimSize = 10
+///   nextStride = 5 * 10 = 50
+static int64_t computeNextStride(int64_t prevStride, int64_t prevDimSize) {
+  if (ShapedType::isDynamicStrideOrOffset(prevStride))
+    return ShapedType::kDynamicStrideOrOffset;
+
+  if (ShapedType::isDynamic(prevDimSize))
+    return ShapedType::kDynamicStrideOrOffset;
+
+  return prevStride * prevDimSize;
+}
+
+/// Helper function for verifying the shape of ExpandShapeOp and CollapseShapeOp
+/// result and operand. Layout maps are verified separately.
+///
+/// If `allowMultipleDynamicDimsPerGroup` is set, multiple dynamic dimensions
+/// are allowed in a reassociation group.
+static LogicalResult
+verifyCollapsedShape(Operation *op, ArrayRef<int64_t> collapsedShape,
+                     ArrayRef<int64_t> expandedShape,
+                     ArrayRef<ReassociationIndices> reassociation,
+                     bool allowMultipleDynamicDimsPerGroup) {
+  // There must be one reassociation group per collapsed dimension.
+  if (collapsedShape.size() != reassociation.size())
+    return op->emitOpError("invalid number of reassociation groups: found ")
+           << reassociation.size() << ", expected " << collapsedShape.size();
+
+  // The next expected expanded dimension index (while iterating over
+  // reassociation indices).
+  int64_t nextDim = 0;
+  for (const auto &it : llvm::enumerate(reassociation)) {
+    ReassociationIndices group = it.value();
+    int64_t collapsedDim = it.index();
+
+    bool foundDynamic = false;
+    for (int64_t expandedDim : group) {
+      if (expandedDim != nextDim++)
+        return op->emitOpError("reassociation indices must be contiguous");
+
+      if (expandedDim >= static_cast<int64_t>(expandedShape.size()))
+        return op->emitOpError("reassociation index ")
+               << expandedDim << " is out of bounds";
+
+      // Check if there are multiple dynamic dims in a reassociation group.
+      if (ShapedType::isDynamic(expandedShape[expandedDim])) {
+        if (foundDynamic && !allowMultipleDynamicDimsPerGroup)
+          return op->emitOpError(
+              "at most one dimension in a reassociation group may be dynamic");
+        foundDynamic = true;
+      }
+    }
+
+    // ExpandShapeOp/CollapseShapeOp may not be used to cast dynamicity.
+    if (ShapedType::isDynamic(collapsedShape[collapsedDim]) != foundDynamic)
+      return op->emitOpError("collapsed dim (")
+             << collapsedDim
+             << ") must be dynamic if and only if reassociation group is "
+                "dynamic";
+
+    // If all dims in the reassociation group are static, the size of the
+    // collapsed dim can be verified.
+    if (!foundDynamic) {
+      int64_t groupSize = 1;
+      for (int64_t expandedDim : group)
+        groupSize *= expandedShape[expandedDim];
+      if (groupSize != collapsedShape[collapsedDim])
+        return op->emitOpError("collapsed dim size (")
+               << collapsedShape[collapsedDim]
+               << ") must equal reassociation group size (" << groupSize << ")";
+    }
+  }
+
+  if (collapsedShape.empty()) {
+    // Rank 0: All expanded dimensions must be 1.
+    for (int64_t d : expandedShape)
+      if (d != 1)
+        return op->emitOpError(
+            "rank 0 memrefs can only be extended/collapsed with/from ones");
+  } else if (nextDim != static_cast<int64_t>(expandedShape.size())) {
+    // Rank >= 1: Number of dimensions among all reassociation groups must match
+    // the result memref rank.
+    return op->emitOpError("expanded rank (")
+           << expandedShape.size()
+           << ") inconsistent with number of reassociation indices (" << nextDim
+           << ")";
+  }
+
+  return success();
+}
+
 SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
   return getSymbolLessAffineMaps(getReassociationExprs());
 }
+
 SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
   return convertReassociationIndicesToExprs(getContext(),
                                             getReassociationIndices());
@@ -1569,6 +1666,7 @@ SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
 SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {
   return getSymbolLessAffineMaps(getReassociationExprs());
 }
+
 SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
   return convertReassociationIndicesToExprs(getContext(),
                                             getReassociationIndices());
@@ -1702,8 +1800,123 @@ static LogicalResult verifyReshapeOp(ReshapeOp op, MemRefType expandedType,
   return success();
 }
 
+/// Compute the layout map after expanding a given source MemRef type with the
+/// specified reassociation indices.
+static FailureOr<AffineMap>
+computeExpandedLayoutMap(MemRefType srcType, ArrayRef<int64_t> resultShape,
+                         ArrayRef<ReassociationIndices> reassociation) {
+  SmallVector<int64_t> srcStrides, resultStrides(resultShape.size(), 0);
+  int64_t srcOffset;
+  if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset)))
+    return failure();
+  assert(srcStrides.size() == reassociation.size() && "invalid reassociation");
+
+  // Ensure that inner strides are the fastest-varying ones. Other source layout
+  // maps are currently not supported.
+  int64_t lastStride = 0;
+  for (int64_t s : llvm::reverse(srcStrides)) {
+    if (!ShapedType::isDynamicStrideOrOffset(s)) {
+      if (s < lastStride)
+        return failure();
+      lastStride = s;
+    }
+  }
+
+  // Iterate over all reassociation groups from the back. Example:
+  // strides       = [1000, ?, 2]
+  // source shape  = [20,  10, 5]
+  // result shape  = [ 2, 10,   2, 5,   5]
+  // reassociation = [[0,  1], [2, 3], [4]]
+  for (const auto &it : llvm::reverse(llvm::zip(reassociation, srcStrides))) {
+    ReassociationIndices indices = std::get<0>(it);
+    int64_t srcGroupStride = std::get<1>(it);
+
+    // The last result dimension (the least significant one) in each
+    // reassociation group has the same stride as its source dimension. E.g.:
+    // reassociation = [[0, 1], [2, 3], [4]]
+    //                      |       |    |
+    //                      v       v    v
+    //                    1000      ?    2
+    resultStrides[indices.pop_back_val()] = srcGroupStride;
+
+    // Compute the strides for the remaining dims in the reassociation group.
+    for (int64_t resultDim : llvm::reverse(indices)) {
+      // E.g.:
+      // reassociation = [[0, 1], [2, 3], [4]]
+      //                   |
+      //                   v
+      //               1000 * 10 = 10000
+      //
+      // If the previous stride or the previous dimension was dynamic, then this
+      // stride will also be dynamic.
+      resultStrides[resultDim] = computeNextStride(resultStrides[resultDim + 1],
+                                                   resultShape[resultDim + 1]);
+    }
+  }
+
+  return makeStridedLinearLayoutMap(resultStrides, srcOffset,
+                                    srcType.getContext());
+}
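For illustration, the stride propagation above can be traced with a small self-contained model (editorial sketch, not part of the patch; plain C++ with `kDyn = -1` standing in for the dynamic stride/size sentinels):

```c++
#include <cassert>
#include <cstdint>
#include <vector>

// Stand-in for ShapedType::kDynamicStrideOrOffset / ShapedType::kDynamicSize.
constexpr int64_t kDyn = -1;

// Mirrors computeNextStride: a stride derived from a dynamic stride or a
// dynamic dimension size is itself dynamic.
static int64_t nextStride(int64_t prevStride, int64_t prevDimSize) {
  return (prevStride == kDyn || prevDimSize == kDyn) ? kDyn
                                                     : prevStride * prevDimSize;
}

int main() {
  // Same example as in the comment above:
  //   srcStrides    = [1000, ?, 2]
  //   result shape  = [2, 10, 2, 5, 5]
  //   reassociation = [[0, 1], [2, 3], [4]]
  std::vector<int64_t> srcStrides = {1000, kDyn, 2};
  std::vector<int64_t> resultShape = {2, 10, 2, 5, 5};
  std::vector<std::vector<int64_t>> groups = {{0, 1}, {2, 3}, {4}};

  std::vector<int64_t> resultStrides(resultShape.size(), 0);
  for (int g = static_cast<int>(groups.size()) - 1; g >= 0; --g) {
    std::vector<int64_t> idx = groups[g];
    // The last dimension of each group inherits the source stride.
    resultStrides[idx.back()] = srcStrides[g];
    idx.pop_back();
    // Remaining dims (right to left): stride = stride(next) * size(next).
    for (auto it = idx.rbegin(); it != idx.rend(); ++it)
      resultStrides[*it] =
          nextStride(resultStrides[*it + 1], resultShape[*it + 1]);
  }
  // Expected: [10000, 1000, ?, ?, 2].
  assert((resultStrides ==
          std::vector<int64_t>{10000, 1000, kDyn, kDyn, 2}));
  return 0;
}
```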
+
+static FailureOr<MemRefType>
+computeExpandedType(MemRefType srcType, ArrayRef<int64_t> resultShape,
+                    ArrayRef<ReassociationIndices> reassociation) {
+  if (srcType.getLayout().isIdentity()) {
+    // If the source is contiguous (i.e., no layout map specified), so is the
+    // result.
+    MemRefLayoutAttrInterface layout;
+    return MemRefType::get(resultShape, srcType.getElementType(), layout,
+                           srcType.getMemorySpace());
+  }
+
+  // Source may not be contiguous. Compute the layout map.
+  FailureOr<AffineMap> computedLayout =
+      computeExpandedLayoutMap(srcType, resultShape, reassociation);
+  if (failed(computedLayout))
+    return failure();
+  auto computedType =
+      MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
+                      srcType.getMemorySpaceAsInt());
+  return canonicalizeStridedLayout(computedType);
+}
+
+void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
+                          ArrayRef<int64_t> resultShape, Value src,
+                          ArrayRef<ReassociationIndices> reassociation) {
+  // Only ranked memref source values are supported.
+  auto srcType = src.getType().cast<MemRefType>();
+  FailureOr<MemRefType> resultType =
+      computeExpandedType(srcType, resultShape, reassociation);
+  // Failure of this assertion usually indicates a problem with the source
+  // type, e.g., could not get strides/offset.
+  assert(succeeded(resultType) && "could not compute layout");
+  build(builder, result, *resultType, src, reassociation);
+}
+
 LogicalResult ExpandShapeOp::verify() {
-  return verifyReshapeOp(*this, getResultType(), getSrcType());
+  MemRefType srcType = getSrcType();
+  MemRefType resultType = getResultType();
+
+  // Verify result shape.
+  if (failed(verifyCollapsedShape(getOperation(), srcType.getShape(),
+                                  resultType.getShape(),
+                                  getReassociationIndices(),
+                                  /*allowMultipleDynamicDimsPerGroup=*/false)))
+    return failure();
+
+  // Compute expected result type (including layout map).
+  FailureOr<MemRefType> expectedResultType = computeExpandedType(
+      srcType, resultType.getShape(), getReassociationIndices());
+  if (failed(expectedResultType))
+    return emitOpError("invalid source layout map");
+
+  // Check actual result type.
+  auto canonicalizedResultType = canonicalizeStridedLayout(resultType);
+  if (*expectedResultType != canonicalizedResultType)
+    return emitOpError("expected expanded type to be ")
+           << *expectedResultType << " but found " << canonicalizedResultType;
+
+  return success();
 }
 
 void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,

diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index a714fcb074284..08638234c2259 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -393,8 +393,17 @@ func @copy_different_eltype(%arg0: memref<2xf32>, %arg1: memref<2xf16>) {
 // -----
 
 func @expand_shape(%arg0: memref<f32>) {
-  // expected-error @+1 {{expected non-zero memref ranks}}
+  // expected-error @+1 {{invalid number of reassociation groups: found 1, expected 0}}
   %0 = memref.expand_shape %arg0 [[0]] : memref<f32> into memref<f32>
+  return
+}
+
+// -----
+
+func @expand_shape(%arg0: memref<f32>) {
+  // expected-error @+1 {{rank 0 memrefs can only be extended/collapsed with/from ones}}
+  %0 = memref.expand_shape %arg0 [] : memref<f32> into memref<1x2xf32>
+  return
 }
 
 // -----
@@ -407,12 +416,22 @@ func @collapse_shape_to_higher_rank(%arg0: memref<f32>) {
 // -----
 
 func @expand_shape_to_smaller_rank(%arg0: memref<1xf32>) {
-  // expected-error @+1 {{expected the type 'memref<f32>' to have higher rank than the type = 'memref<1xf32>'}}
+  // expected-error @+1 {{op reassociation index 0 is out of bounds}}
   %0 = memref.expand_shape %arg0 [[0]] : memref<1xf32> into memref<f32>
 }
 
 // -----
 
+func @expand_shape_invalid_result_layout(
+    %arg0: memref<30x20xf32, offset : 100, strides : [4000, 2]>) {
+  // expected-error @+1 {{expected expanded type to be 'memref<2x15x20xf32, affine_map<(d0, d1, d2) -> (d0 * 60000 + d1 * 4000 + d2 * 2 + 100)>>' but found 'memref<2x15x20xf32, affine_map<(d0, d1, d2) -> (d0 * 5000 + d1 * 4000 + d2 * 2 + 100)>>'}}
+  %0 = memref.expand_shape %arg0 [[0, 1], [2]] :
+      memref<30x20xf32, offset : 100, strides : [4000, 2]>
+      into memref<2x15x20xf32, offset : 100, strides : [5000, 4000, 2]>
+}
+
+// -----
+
 func @collapse_shape(%arg0: memref<?xf32>) {
   // expected-error @+1 {{expected to collapse or expand dims}}
   %0 = memref.collapse_shape %arg0 [[0]] : memref<?xf32> into memref<?xf32>
@@ -446,7 +465,7 @@ func @collapse_shape_wrong_collapsed_type(%arg0: memref<?x?x?xf32>) {
 
 func @expand_shape_illegal_dynamic_memref
   (%arg0: memref<?x?x?xf32>) -> memref<?x?x?x4x?xf32> {
-  // expected-error @+1 {{invalid to have a single dimension (2) expanded into multiple dynamic dims (2,4)}}
+  // expected-error @+1 {{at most one dimension in a reassociation group may be dynamic}}
   %0 = memref.expand_shape %arg0 [[0], [1], [2, 3, 4]]
       : memref<?x?x?xf32> into memref<?x?x?x4x?xf32>
   return %0 : memref<?x?x?x4x?xf32>
@@ -456,7 +475,7 @@ func @expand_shape_illegal_dynamic_memref
 
 func @expand_shape_illegal_static_memref
   (%arg0: memref<2x3x20xf32>) -> memref<2x3x2x4x5xf32> {
-  // expected-error @+1 {{expected dimension 2 of collapsed type to be static value of 40}}
+  // expected-error @+1 {{collapsed dim size (20) must equal reassociation group size (40)}}
   %0 = memref.expand_shape %arg0 [[0], [1], [2, 3, 4]]
       : memref<2x3x20xf32> into memref<2x3x2x4x5xf32>
   return %0 : memref<2x3x2x4x5xf32>
@@ -476,7 +495,7 @@ func @collapse_shape_illegal_static_memref
 
 func @expand_shape_illegal_mixed_memref(%arg0 : memref<?x?xf32>)
     -> memref<?x4x5xf32> {
-  // expected-error @+1 {{expected dimension 1 of collapsed type to be static value of 5}}
+  // expected-error @+1 {{collapsed dim (1) must be dynamic if and only if reassociation group is dynamic}}
   %0 = memref.expand_shape %arg0 [[0, 1], [2]]
       : memref<?x?xf32> into memref<?x4x5xf32>
   return %0 : memref<?x4x5xf32>
@@ -486,7 +505,7 @@ func @expand_shape_illegal_mixed_memref(%arg0 : memref<?x?xf32>)
 
 func @expand_shape_illegal_mixed_memref_2(%arg0 : memref<?x?xf32>)
     -> memref<?x4x5xf32> {
-  // expected-error @+1 {{expected dimension 1 of collapsed type to be static value of 20}}
+  // expected-error @+1 {{collapsed dim (1) must be dynamic if and only if reassociation group is dynamic}}
   %0 = memref.expand_shape %arg0 [[0], [1, 2]]
       : memref<?x?xf32> into memref<?x4x5xf32>
   return %0 : memref<?x4x5xf32>
@@ -494,6 +513,28 @@ func @expand_shape_illegal_mixed_memref_2(%arg0 : memref<?x?xf32>)
 
 // -----
 
+func @expand_shape_unsupported_src_layout(
+    %arg0 : memref<20x2x10x5xf32, offset: 0, strides: [100, 10, 50, 1]>)
+    -> memref<20x2x2x5x5xf32, offset : 0, strides : [100, 10, 250, 50, 1]> {
+  // expected-error @+1 {{invalid source layout map}}
+  %0 = memref.expand_shape %arg0 [[0], [1], [2, 3], [4]] :
+      memref<20x2x10x5xf32, offset: 0, strides: [100, 10, 50, 1]>
+      into memref<20x2x2x5x5xf32, offset : 0, strides : [100, 10, 250, 50, 1]>
+  return %0 : memref<20x2x2x5x5xf32, offset : 0, strides : [100, 10, 250, 50, 1]>
+}
+
+// -----
+
+func @expand_shape_invalid_static_dim_size(%arg0 : memref<?x21xf32>)
+    -> memref<?x4x5xf32> {
+  // expected-error @+1 {{collapsed dim size (21) must equal reassociation group size (20)}}
+  %0 = memref.expand_shape %arg0 [[0], [1, 2]]
+      : memref<?x21xf32> into memref<?x4x5xf32>
+  return %0 : memref<?x4x5xf32>
+}
+
+// -----
+
 func @collapse_shape_illegal_mixed_memref(%arg0 : memref<?x4x5xf32>)
     -> memref<?x?xf32> {
   // expected-error @+1 {{expected dimension 1 of collapsed type to be static value of 5}}

diff --git a/mlir/test/Dialect/MemRef/ops.mlir b/mlir/test/Dialect/MemRef/ops.mlir
index 6191cdab02e2e..d001f35580ab0 100644
--- a/mlir/test/Dialect/MemRef/ops.mlir
+++ b/mlir/test/Dialect/MemRef/ops.mlir
@@ -104,107 +104,147 @@ func @memref_alloca_scope() {
   return
 }
 
-func @expand_collapse_shape_static(%arg0: memref<3x4x5xf32>,
-                                   %arg1: tensor<3x4x5xf32>,
-                                   %arg2: tensor<3x?x5xf32>) {
+// CHECK-LABEL: func @expand_collapse_shape_static
+func @expand_collapse_shape_static(
+    %arg0: memref<3x4x5xf32>,
+    %arg1: tensor<3x4x5xf32>,
+    %arg2: tensor<3x?x5xf32>,
+    %arg3: memref<30x20xf32, offset : 100, strides : [4000, 2]>,
+    %arg4: memref<1x5xf32, affine_map<(d0, d1)[s0] -> (d0 * 5 + s0 + d1)>>,
+    %arg5: memref<f32>) {
   // Reshapes that collapse and expand back a contiguous buffer.
+//       CHECK:   memref.collapse_shape {{.*}} {{\[}}[0, 1], [2]]
+//  CHECK-SAME:     memref<3x4x5xf32> into memref<12x5xf32>
   %0 = memref.collapse_shape %arg0 [[0, 1], [2]] :
     memref<3x4x5xf32> into memref<12x5xf32>
+
+//       CHECK:   memref.expand_shape {{.*}} {{\[}}[0, 1], [2]]
+//  CHECK-SAME:     memref<12x5xf32> into memref<3x4x5xf32>
   %r0 = memref.expand_shape %0 [[0, 1], [2]] :
     memref<12x5xf32> into memref<3x4x5xf32>
+
+//       CHECK:   memref.collapse_shape {{.*}} {{\[}}[0], [1, 2]]
+//  CHECK-SAME:     memref<3x4x5xf32> into memref<3x20xf32>
   %1 = memref.collapse_shape %arg0 [[0], [1, 2]] :
     memref<3x4x5xf32> into memref<3x20xf32>
+
+//       CHECK:   memref.expand_shape {{.*}} {{\[}}[0], [1, 2]]
+//  CHECK-SAME:     memref<3x20xf32> into memref<3x4x5xf32>
   %r1 = memref.expand_shape %1 [[0], [1, 2]] :
     memref<3x20xf32> into memref<3x4x5xf32>
+
+//       CHECK:   memref.collapse_shape {{.*}} {{\[}}[0, 1, 2]]
+//  CHECK-SAME:     memref<3x4x5xf32> into memref<60xf32>
   %2 = memref.collapse_shape %arg0 [[0, 1, 2]] :
     memref<3x4x5xf32> into memref<60xf32>
+
+//       CHECK:   memref.expand_shape {{.*}} {{\[}}[0, 1, 2]]
+//  CHECK-SAME:     memref<60xf32> into memref<3x4x5xf32>
   %r2 = memref.expand_shape %2 [[0, 1, 2]] :
-    memref<60xf32> into memref<3x4x5xf32>
+      memref<60xf32> into memref<3x4x5xf32>
+
+//       CHECK:   memref.expand_shape {{.*}} []
+//  CHECK-SAME:     memref<f32> into memref<1x1xf32>
+  %r5 = memref.expand_shape %arg5 [] :
+      memref<f32> into memref<1x1xf32>
+
+// Reshapes with a custom layout map.
+//       CHECK:   memref.expand_shape {{.*}} {{\[}}[0], [1, 2]]
+  %l0 = memref.expand_shape %arg3 [[0], [1, 2]] :
+      memref<30x20xf32, offset : 100, strides : [4000, 2]>
+      into memref<30x4x5xf32, offset : 100, strides : [4000, 10, 2]>
+
+//       CHECK:   memref.expand_shape {{.*}} {{\[}}[0, 1], [2]]
+  %l1 = memref.expand_shape %arg3 [[0, 1], [2]] :
+      memref<30x20xf32, offset : 100, strides : [4000, 2]>
+      into memref<2x15x20xf32, offset : 100, strides : [60000, 4000, 2]>
+
+//       CHECK:   memref.expand_shape {{.*}} {{\[}}[0], [1, 2]]
+  %r4 = memref.expand_shape %arg4 [[0], [1, 2]] :
+      memref<1x5xf32, affine_map<(d0, d1)[s0] -> (d0 * 5 + s0 + d1)>> into
+      memref<1x1x5xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 5 + s0 + d2 + d1 * 5)>>
+
   // Reshapes that expand and collapse back a contiguous buffer with some 1's.
+//       CHECK:   memref.expand_shape {{.*}} {{\[}}[0, 1], [2], [3, 4]]
+//  CHECK-SAME:     memref<3x4x5xf32> into memref<1x3x4x1x5xf32>
   %3 = memref.expand_shape %arg0 [[0, 1], [2], [3, 4]] :
     memref<3x4x5xf32> into memref<1x3x4x1x5xf32>
+
+//       CHECK:   memref.collapse_shape {{.*}} {{\[}}[0, 1], [2], [3, 4]]
+//  CHECK-SAME:     memref<1x3x4x1x5xf32> into memref<3x4x5xf32>
   %r3 = memref.collapse_shape %3 [[0, 1], [2], [3, 4]] :
     memref<1x3x4x1x5xf32> into memref<3x4x5xf32>
+
   // Reshapes on tensors.
+//       CHECK:   tensor.expand_shape {{.*}}: tensor<3x4x5xf32> into tensor<1x3x4x1x5xf32>
   %t0 = tensor.expand_shape %arg1 [[0, 1], [2], [3, 4]] :
     tensor<3x4x5xf32> into tensor<1x3x4x1x5xf32>
+
+//       CHECK:   tensor.collapse_shape {{.*}}: tensor<1x3x4x1x5xf32> into tensor<3x4x5xf32>
   %rt0 = tensor.collapse_shape %t0 [[0, 1], [2], [3, 4]] :
     tensor<1x3x4x1x5xf32> into tensor<3x4x5xf32>
+
+//       CHECK:   tensor.expand_shape {{.*}}: tensor<3x?x5xf32> into tensor<1x3x?x1x5xf32>
   %t1 = tensor.expand_shape %arg2 [[0, 1], [2], [3, 4]] :
     tensor<3x?x5xf32> into tensor<1x3x?x1x5xf32>
+
+//       CHECK:   tensor.collapse_shape {{.*}}: tensor<1x3x?x1x5xf32> into tensor<1x?x5xf32>
   %rt1 = tensor.collapse_shape %t1 [[0], [1, 2], [3, 4]] :
     tensor<1x3x?x1x5xf32> into tensor<1x?x5xf32>
   return
 }
-// CHECK-LABEL: func @expand_collapse_shape_static
-//       CHECK:   memref.collapse_shape {{.*}} {{\[}}[0, 1], [2]]
-//  CHECK-SAME:     memref<3x4x5xf32> into memref<12x5xf32>
-//       CHECK:   memref.expand_shape {{.*}} {{\[}}[0, 1], [2]]
-//  CHECK-SAME:     memref<12x5xf32> into memref<3x4x5xf32>
-//       CHECK:   memref.collapse_shape {{.*}} {{\[}}[0], [1, 2]]
-//  CHECK-SAME:     memref<3x4x5xf32> into memref<3x20xf32>
-//       CHECK:   memref.expand_shape {{.*}} {{\[}}[0], [1, 2]]
-//  CHECK-SAME:     memref<3x20xf32> into memref<3x4x5xf32>
-//       CHECK:   memref.collapse_shape {{.*}} {{\[}}[0, 1, 2]]
-//  CHECK-SAME:     memref<3x4x5xf32> into memref<60xf32>
-//       CHECK:   memref.expand_shape {{.*}} {{\[}}[0, 1, 2]]
-//  CHECK-SAME:     memref<60xf32> into memref<3x4x5xf32>
-//       CHECK:   memref.expand_shape {{.*}} {{\[}}[0, 1], [2], [3, 4]]
-//  CHECK-SAME:     memref<3x4x5xf32> into memref<1x3x4x1x5xf32>
-//       CHECK:   memref.collapse_shape {{.*}} {{\[}}[0, 1], [2], [3, 4]]
-//  CHECK-SAME:     memref<1x3x4x1x5xf32> into memref<3x4x5xf32>
-//
-//       CHECK:   tensor.expand_shape {{.*}}: tensor<3x4x5xf32> into tensor<1x3x4x1x5xf32>
-//       CHECK:   tensor.collapse_shape {{.*}}: tensor<1x3x4x1x5xf32> into tensor<3x4x5xf32>
-//       CHECK:   tensor.expand_shape {{.*}}: tensor<3x?x5xf32> into tensor<1x3x?x1x5xf32>
-//       CHECK:   tensor.collapse_shape {{.*}}: tensor<1x3x?x1x5xf32> into tensor<1x?x5xf32>
-
 
+// CHECK-LABEL: func @expand_collapse_shape_dynamic
 func @expand_collapse_shape_dynamic(%arg0: memref<?x?x?xf32>,
          %arg1: memref<?x?x?xf32, offset : 0, strides : [?, ?, 1]>,
          %arg2: memref<?x?x?xf32, offset : ?, strides : [?, ?, 1]>,
          %arg3: memref<?x42xf32, offset : 0, strides : [42, 1]>) {
+//       CHECK:   memref.collapse_shape {{.*}} {{\[}}[0, 1], [2]]
+//  CHECK-SAME:     memref<?x?x?xf32> into memref<?x?xf32>
   %0 = memref.collapse_shape %arg0 [[0, 1], [2]] :
     memref<?x?x?xf32> into memref<?x?xf32>
+
+//       CHECK:   memref.expand_shape {{.*}} {{\[}}[0, 1], [2]]
+//  CHECK-SAME:     memref<?x?xf32> into memref<?x4x?xf32>
   %r0 = memref.expand_shape %0 [[0, 1], [2]] :
     memref<?x?xf32> into memref<?x4x?xf32>
+
+//       CHECK:   memref.collapse_shape {{.*}} {{\[}}[0, 1], [2]]
+//  CHECK-SAME:     memref<?x?x?xf32, #[[$strided3DOFF0]]> into memref<?x?xf32, #[[$strided2DOFF0]]>
   %1 = memref.collapse_shape %arg1 [[0, 1], [2]] :
     memref<?x?x?xf32, offset : 0, strides : [?, ?, 1]> into
     memref<?x?xf32, offset : 0, strides : [?, 1]>
+
+//       CHECK:   memref.expand_shape {{.*}} {{\[}}[0, 1], [2]]
+//  CHECK-SAME:     memref<?x?xf32, #[[$strided2DOFF0]]> into memref<?x4x?xf32, #[[$strided3DOFF0]]>
   %r1 = memref.expand_shape %1 [[0, 1], [2]] :
     memref<?x?xf32, offset : 0, strides : [?, 1]> into
     memref<?x4x?xf32, offset : 0, strides : [?, ?, 1]>
+
+//       CHECK:   memref.collapse_shape {{.*}} {{\[}}[0, 1], [2]]
+//  CHECK-SAME:     memref<?x?x?xf32, #[[$strided3D]]> into memref<?x?xf32, #[[$strided2D]]>
   %2 = memref.collapse_shape %arg2 [[0, 1], [2]] :
     memref<?x?x?xf32, offset : ?, strides : [?, ?, 1]> into
     memref<?x?xf32, offset : ?, strides : [?, 1]>
+
+//       CHECK:   memref.expand_shape {{.*}} {{\[}}[0, 1], [2]]
+//  CHECK-SAME:     memref<?x?xf32, #[[$strided2D]]> into memref<?x4x?xf32, #[[$strided3D]]>
   %r2 = memref.expand_shape %2 [[0, 1], [2]] :
     memref<?x?xf32, offset : ?, strides : [?, 1]> into
     memref<?x4x?xf32, offset : ?, strides : [?, ?, 1]>
+
+//       CHECK:   memref.collapse_shape {{.*}} {{\[}}[0, 1]]
+//  CHECK-SAME:     memref<?x42xf32, #[[$strided2D42]]> into memref<?xf32>
   %3 = memref.collapse_shape %arg3 [[0, 1]] :
     memref<?x42xf32, offset : 0, strides : [42, 1]> into
     memref<?xf32, offset : 0, strides : [1]>
+
+//       CHECK:   memref.expand_shape {{.*}} {{\[}}[0, 1]]
+//  CHECK-SAME:     memref<?xf32> into memref<?x42xf32>
   %r3 = memref.expand_shape %3 [[0, 1]] :
-    memref<?xf32, offset : 0, strides : [1]> into
-    memref<?x42xf32, offset : 0, strides : [42, 1]>
+    memref<?xf32, offset : 0, strides : [1]> into memref<?x42xf32>
   return
 }
-// CHECK-LABEL: func @expand_collapse_shape_dynamic
-//       CHECK:   memref.collapse_shape {{.*}} {{\[}}[0, 1], [2]]
-//  CHECK-SAME:     memref<?x?x?xf32> into memref<?x?xf32>
-//       CHECK:   memref.expand_shape {{.*}} {{\[}}[0, 1], [2]]
-//  CHECK-SAME:     memref<?x?xf32> into memref<?x4x?xf32>
-//       CHECK:   memref.collapse_shape {{.*}} {{\[}}[0, 1], [2]]
-//  CHECK-SAME:     memref<?x?x?xf32, #[[$strided3DOFF0]]> into memref<?x?xf32, #[[$strided2DOFF0]]>
-//       CHECK:   memref.expand_shape {{.*}} {{\[}}[0, 1], [2]]
-//  CHECK-SAME:     memref<?x?xf32, #[[$strided2DOFF0]]> into memref<?x4x?xf32, #[[$strided3DOFF0]]>
-//       CHECK:   memref.collapse_shape {{.*}} {{\[}}[0, 1], [2]]
-//  CHECK-SAME:     memref<?x?x?xf32, #[[$strided3D]]> into memref<?x?xf32, #[[$strided2D]]>
-//       CHECK:   memref.expand_shape {{.*}} {{\[}}[0, 1], [2]]
-//  CHECK-SAME:     memref<?x?xf32, #[[$strided2D]]> into memref<?x4x?xf32, #[[$strided3D]]>
-//       CHECK:   memref.collapse_shape {{.*}} {{\[}}[0, 1]]
-//  CHECK-SAME:     memref<?x42xf32, #[[$strided2D42]]> into memref<?xf32>
-//       CHECK:   memref.expand_shape {{.*}} {{\[}}[0, 1]]
-//  CHECK-SAME:     memref<?xf32> into memref<?x42xf32, #[[$strided2D42]]>
 
 func @expand_collapse_shape_zero_dim(%arg0 : memref<1x1xf32>, %arg1 : memref<f32>)
     -> (memref<f32>, memref<1x1xf32>) {


        

