[Mlir-commits] [mlir] ec13f6c - [mlir][Linalg] Add verification checks to disallow illegal reshape ops.

llvmlistbot at llvm.org
Fri Jan 8 10:55:05 PST 2021


Author: MaheshRavishankar
Date: 2021-01-08T10:54:46-08:00
New Revision: ec13f6c3e56952c94909a36a590c679a6a57a046

URL: https://github.com/llvm/llvm-project/commit/ec13f6c3e56952c94909a36a590c679a6a57a046
DIFF: https://github.com/llvm/llvm-project/commit/ec13f6c3e56952c94909a36a590c679a6a57a046.diff

LOG: [mlir][Linalg] Add verification checks to disallow illegal reshape ops.

The existing verification of reshape ops in linalg (linalg.reshape and
linalg.tensor_reshape) allows specification of illegal ops, where
- A dynamic dimension is expanded into multiple dynamic
  dimensions. This is ill-specified.
- A static dimension is expanded into a dynamic dimension, or vice versa.
- The product of the extents of the static dimensions in the expanded type
  doesn't match the static dimension of the collapsed type.
This change makes all of these cases illegal. It also means that some
pessimizations in canonicalization, which were needed because of the
incomplete semantics of the operation, can be dropped.
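
For illustration (this mirrors one of the new tests added to
invalid.mlir below), expanding a single dynamic dimension of the source
into more than one dynamic dimension of the result is now rejected by
the verifier:

  // Rejected: dimension 2 of the source is expanded into two dynamic
  // dimensions (positions 2 and 4 of the result).
  %0 = linalg.tensor_reshape %arg0
    [affine_map<(d0, d1, d2, d3, d4) -> (d0)>,
     affine_map<(d0, d1, d2, d3, d4) -> (d1)>,
     affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d4)>] :
    tensor<?x?x?xf32> into tensor<?x?x?x4x?xf32>

Conversely, because the verifier now guarantees consistent shapes, a
producer-consumer pair of reshapes with dynamic dimensions (e.g.
tensor<?x?xf32> into tensor<?x4x?xf32> and back) can be folded away
without requiring fully static shapes, as the updated canonicalize.mlir
tests show.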

Differential Revision: https://reviews.llvm.org/D93724

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/test/Dialect/Linalg/canonicalize.mlir
    mlir/test/Dialect/Linalg/invalid.mlir
    mlir/test/Dialect/Linalg/roundtrip.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 529ba35a0b87..8a97753e1a5c 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -820,13 +820,10 @@ template <typename ReshapeOpTy>
 static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp,
                                   ArrayRef<Attribute> operands) {
   // Fold producer-consumer reshape ops that where the operand type of the
-  // producer is same as the return type of the consumer. This can only be
-  // verified if the shapes in question are static.
+  // producer is same as the return type of the consumer.
   ReshapeOpTy reshapeSrcOp =
       reshapeOp.src().template getDefiningOp<ReshapeOpTy>();
-  if (reshapeSrcOp && reshapeSrcOp.getSrcType().hasStaticShape() &&
-      reshapeOp.getResultType().hasStaticShape() &&
-      reshapeSrcOp.getSrcType() == reshapeOp.getResultType())
+  if (reshapeSrcOp && reshapeSrcOp.getSrcType() == reshapeOp.getResultType())
     return reshapeSrcOp.src();
   // Reshape of a constant can be replaced with a new constant.
   if (auto elements = operands.front().dyn_cast_or_null<DenseElementsAttr>()) {
@@ -1030,6 +1027,57 @@ void mlir::linalg::ReshapeOp::build(OpBuilder &b, OperationState &result,
 
 Value mlir::linalg::ReshapeOp::getViewSource() { return src(); }
 
+/// Verify the shapes of the reshaped types using the following rules:
+/// 1) if a dimension in the collapsed type is static, then the corresponding
+///    dimensions in the expanded shape should be
+///    a) static
+///    b) such that the product of their extents matches the collapsed shape.
+/// 2) if a dimension in the collapsed type is dynamic, one and only one of
+///    the corresponding dimensions in the expanded type should be dynamic.
+///    This rule is only needed for reshape operations that are expanding.
+template <typename OpTy>
+static LogicalResult verifyReshapeLikeShapes(OpTy op, ShapedType collapsedType,
+                                             ShapedType expandedType,
+                                             bool isExpandingReshape) {
+  ArrayRef<int64_t> collapsedShape = collapsedType.getShape();
+  ArrayRef<int64_t> expandedShape = expandedType.getShape();
+  unsigned expandedDimStart = 0;
+  for (auto map : llvm::enumerate(op.getReassociationMaps())) {
+    Optional<int64_t> dynamicDims;
+    int64_t linearizedStaticShape = 1;
+    for (auto dim : llvm::enumerate(expandedShape.slice(
+             expandedDimStart, map.value().getNumResults()))) {
+      if (ShapedType::isDynamic(dim.value())) {
+        if (isExpandingReshape && dynamicDims) {
+          return op->emitOpError("invalid to have a single dimension (")
+                 << map.index() << ") expanded into multiple dynamic dims ("
+                 << expandedDimStart + dynamicDims.getValue() << ","
+                 << expandedDimStart + dim.index() << ")";
+        }
+        dynamicDims = dim.index();
+      } else {
+        linearizedStaticShape *= dim.value();
+      }
+    }
+    if (dynamicDims) {
+      if (!ShapedType::isDynamic(collapsedShape[map.index()])) {
+        return op->emitOpError("expected dimension ")
+               << map.index()
+               << " of collapsed type to be dynamic since one or more of the "
+                  "corresponding dimensions in the expanded type is dynamic";
+      }
+    } else {
+      if (collapsedShape[map.index()] != linearizedStaticShape) {
+        return op->emitOpError("expected dimension ")
+               << map.index() << " of collapsed type to be static value of "
+               << linearizedStaticShape << " ";
+      }
+    }
+    expandedDimStart += map.value().getNumResults();
+  }
+  return success();
+}
+
 // Common verifier for reshape-like types. Fills `expandedType` and
 // `collapsedType` with the proper `src` or `result` type.
 template <typename Op, typename T>
@@ -1073,7 +1121,7 @@ static LogicalResult verifyReshapeLikeTypes(Op op, T &expandedType,
   if (!isReassociationValid(maps, &invalidIdx))
     return op.emitOpError("expected reassociation map #")
            << invalidIdx << " to be valid and contiguous";
-  return success();
+  return verifyReshapeLikeShapes(op, collapsedType, expandedType, !isCollapse);
 }
 
 static LogicalResult verify(ReshapeOp op) {
@@ -1152,8 +1200,6 @@ static LogicalResult verify(TensorReshapeOp op) {
   if (failed(verifyReshapeLikeTypes(op, expandedType, collapsedType)))
     return failure();
   auto maps = getAffineMaps(op.reassociation());
-  // TODO: expanding a ? with a non-constant is under-specified. Error
-  // out.
   RankedTensorType expectedType =
       computeTensorReshapeCollapsedType(expandedType, maps);
   if (collapsedType != expectedType)

diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index faac64c0d91a..4102a1326b96 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -43,8 +43,6 @@ func @collapsing_tensor_reshapes(%arg0 : tensor<?x?x?x?x?xf32>) -> tensor<?x?xf3
 
 // -----
 
-// -----
-
 func @collapsing_tensor_reshapes_to_zero_dim(%arg0 : tensor<1x1x1xf32>)
                                              -> tensor<f32> {
   %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1, d2) -> (d0, d1, d2)>] :
@@ -71,18 +69,18 @@ func @collapsing_memref_reshapes_to_zero_dim(%arg0 : memref<1x1x1xf32>)
 
 // -----
 
-func @expanding_tensor_reshapes(%arg0 : tensor<?x?xf32>) -> tensor<?x?x?x?x?xf32>
+func @expanding_tensor_reshapes(%arg0 : tensor<?x?xf32>) -> tensor<?x6x4x?x5xf32>
 {
   %0 = linalg.tensor_reshape %arg0
          [affine_map<(d0, d1, d2) -> (d0, d1)>,
           affine_map<(d0, d1, d2) -> (d2)>] :
-       tensor<?x?xf32> into tensor<?x?x?xf32>
+       tensor<?x?xf32> into tensor<?x4x?xf32>
   %1 = linalg.tensor_reshape %0
          [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>,
           affine_map<(d0, d1, d2, d3, d4) -> (d2)>,
           affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>] :
-       tensor<?x?x?xf32> into tensor<?x?x?x?x?xf32>
-  return %1 : tensor<?x?x?x?x?xf32>
+       tensor<?x4x?xf32> into tensor<?x6x4x?x5xf32>
+  return %1 : tensor<?x6x4x?x5xf32>
 }
 //   CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
 //   CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>
@@ -113,18 +111,18 @@ func @collapsing_memref_reshapes(%arg0 : memref<?x?x?x?x?xf32>) -> memref<?x?xf3
 
 // -----
 
-func @expanding_memref_reshapes(%arg0 : memref<?x?xf32>) -> memref<?x?x?x?x?xf32>
+func @expanding_memref_reshapes(%arg0 : memref<?x?xf32>) -> memref<?x6x4x5x?xf32>
 {
   %0 = linalg.reshape %arg0
          [affine_map<(d0, d1, d2) -> (d0, d1)>,
           affine_map<(d0, d1, d2) -> (d2)>] :
-       memref<?x?xf32> into memref<?x?x?xf32>
+       memref<?x?xf32> into memref<?x4x?xf32>
   %1 = linalg.reshape %0
          [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>,
           affine_map<(d0, d1, d2, d3, d4) -> (d2)>,
           affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>] :
-       memref<?x?x?xf32> into memref<?x?x?x?x?xf32>
-  return %1 : memref<?x?x?x?x?xf32>
+       memref<?x4x?xf32> into memref<?x6x4x5x?xf32>
+  return %1 : memref<?x6x4x5x?xf32>
 }
 //   CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
 //   CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>
@@ -178,21 +176,20 @@ func @fold_tensor_reshape(%arg0 : tensor<12x4xf32>) -> tensor<12x4xf32>
 
 // -----
 
-func @no_fold_tensor_reshape(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32>
+func @fold_tensor_reshape_dynamic(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32>
 {
   %0 = linalg.tensor_reshape %arg0
          [affine_map<(d0, d1, d2) -> (d0, d1)>,
           affine_map<(d0, d1, d2) -> (d2)>] :
-       tensor<?x?xf32> into tensor<?x?x?xf32>
+       tensor<?x?xf32> into tensor<?x4x?xf32>
   %1 = linalg.tensor_reshape %0
          [affine_map<(d0, d1, d2) -> (d0, d1)>,
           affine_map<(d0, d1, d2) -> (d2)>] :
-       tensor<?x?x?xf32> into tensor<?x?xf32>
+       tensor<?x4x?xf32> into tensor<?x?xf32>
   return %1 : tensor<?x?xf32>
 }
-// CHECK-LABEL: @no_fold_tensor_reshape
-//       CHECK:   linalg.tensor_reshape
-//       CHECK:   linalg.tensor_reshape
+// CHECK-LABEL: @fold_tensor_reshape_dynamic
+//   CHECK-NOT:   linalg.tensor_reshape
 
 // -----
 
@@ -213,21 +210,20 @@ func @fold_memref_reshape(%arg0 : memref<12x4xf32>) -> memref<12x4xf32>
 
 // -----
 
-func @no_fold_memref_reshape(%arg0 : memref<?x?xf32>) -> memref<?x?xf32>
+func @fold_memref_reshape_dynamic(%arg0 : memref<?x?xf32>) -> memref<?x?xf32>
 {
   %0 = linalg.reshape %arg0
          [affine_map<(d0, d1, d2) -> (d0, d1)>,
           affine_map<(d0, d1, d2) -> (d2)>] :
-       memref<?x?xf32> into memref<?x?x?xf32>
+       memref<?x?xf32> into memref<?x4x?xf32>
   %1 = linalg.reshape %0
          [affine_map<(d0, d1, d2) -> (d0, d1)>,
           affine_map<(d0, d1, d2) -> (d2)>] :
-       memref<?x?x?xf32> into memref<?x?xf32>
+       memref<?x4x?xf32> into memref<?x?xf32>
   return %1 : memref<?x?xf32>
 }
-// CHECK-LABEL: @no_fold_memref_reshape
-//       CHECK:   linalg.reshape
-//       CHECK:   linalg.reshape
+// CHECK-LABEL: @fold_memref_reshape_dynamic
+//   CHECK-NOT:   linalg.reshape
 
 // -----
 

diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 95a663d19f0d..4359eebebbc1 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -409,3 +409,211 @@ func @matching_inits(%m: memref<?x?xf32>, %t: tensor<?x?xf32>) {
                         -> tensor<?xf32>
   return
 }
+
+
+// -----
+
+func @init_tensor_err(%arg0 : index, %arg1 : index)
+{
+  // expected-error @+1 {{specified type 'tensor<4x?x?x5xf32>' does not match the inferred type 'tensor<4x5x?x?xf32>'}}
+  %1 = linalg.init_tensor [4, 5, %arg0, %arg1] : tensor<4x?x?x5xf32>
+  return
+}
+
+// -----
+
+func @init_tensor_err(%arg0 : index)
+{
+  // expected-error @+1 {{expected 4 sizes values}}
+  %1 = linalg.init_tensor [4, 5, %arg0] : tensor<4x?x?x5xf32>
+  return
+}
+
+// -----
+
+func @init_tensor_err(%arg0 : index)
+{
+  // expected-error @+1 {{expected 2 dynamic sizes values}}
+  %1 = "linalg.init_tensor"(%arg0) {static_sizes = [4, -1, -1, 5]} : (index) -> tensor<4x?x?x5xf32>
+  return
+}
+
+// -----
+
+func @illegal_expanding_reshape_dynamic_tensor
+  (%arg0: tensor<?x?x?xf32>) -> tensor<?x?x?x4x?xf32>
+{
+  // expected-error @+1 {{invalid to have a single dimension (2) expanded into multiple dynamic dims (2,4)}}
+  %0 = linalg.tensor_reshape %arg0
+    [affine_map<(d0, d1, d2, d3, d4) -> (d0)>,
+     affine_map<(d0, d1, d2, d3, d4) -> (d1)>,
+     affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d4)>] :
+    tensor<?x?x?xf32> into tensor<?x?x?x4x?xf32>
+  return %0 : tensor<?x?x?x4x?xf32>
+}
+
+// -----
+
+func @illegal_expanding_reshape_dynamic_memref
+  (%arg0: memref<?x?x?xf32>) -> memref<?x?x?x4x?xf32>
+{
+  // expected-error @+1 {{invalid to have a single dimension (2) expanded into multiple dynamic dims (2,4)}}
+  %0 = linalg.reshape %arg0
+    [affine_map<(d0, d1, d2, d3, d4) -> (d0)>,
+     affine_map<(d0, d1, d2, d3, d4) -> (d1)>,
+     affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d4)>] :
+    memref<?x?x?xf32> into memref<?x?x?x4x?xf32>
+  return %0 : memref<?x?x?x4x?xf32>
+}
+
+// -----
+
+func @illegal_expanding_reshape_static_tensor
+  (%arg0: tensor<2x3x20xf32>) -> tensor<2x3x2x4x5xf32>
+{
+  // expected-error @+1 {{expected dimension 2 of collapsed type to be static value of 40}}
+  %0 = linalg.tensor_reshape %arg0
+    [affine_map<(d0, d1, d2, d3, d4) -> (d0)>,
+     affine_map<(d0, d1, d2, d3, d4) -> (d1)>,
+     affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d4)>] :
+    tensor<2x3x20xf32> into tensor<2x3x2x4x5xf32>
+  return %0 : tensor<2x3x2x4x5xf32>
+}
+
+// -----
+
+func @illegal_collapsing_reshape_static_tensor
+  (%arg0: tensor<2x3x2x4x5xf32>) -> tensor<2x3x20xf32>
+{
+  // expected-error @+1 {{expected dimension 2 of collapsed type to be static value of 40}}
+  %0 = linalg.tensor_reshape %arg0
+    [affine_map<(d0, d1, d2, d3, d4) -> (d0)>,
+     affine_map<(d0, d1, d2, d3, d4) -> (d1)>,
+     affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d4)>] :
+    tensor<2x3x2x4x5xf32> into tensor<2x3x20xf32>
+  return %0 : tensor<2x3x20xf32>
+}
+
+// -----
+
+func @illegal_expanding_reshape_static_memref
+  (%arg0: memref<2x3x20xf32>) -> memref<2x3x2x4x5xf32>
+{
+  // expected-error @+1 {{expected dimension 2 of collapsed type to be static value of 40}}
+  %0 = linalg.reshape %arg0
+    [affine_map<(d0, d1, d2, d3, d4) -> (d0)>,
+     affine_map<(d0, d1, d2, d3, d4) -> (d1)>,
+     affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d4)>] :
+    memref<2x3x20xf32> into memref<2x3x2x4x5xf32>
+  return %0 : memref<2x3x2x4x5xf32>
+}
+
+// -----
+
+func @illegal_collapsing_reshape_static_memref
+  (%arg0: memref<2x3x2x4x5xf32>) -> memref<2x3x20xf32>
+{
+  // expected-error @+1 {{expected dimension 2 of collapsed type to be static value of 40}}
+  %0 = linalg.reshape %arg0
+    [affine_map<(d0, d1, d2, d3, d4) -> (d0)>,
+     affine_map<(d0, d1, d2, d3, d4) -> (d1)>,
+     affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d4)>] :
+    memref<2x3x2x4x5xf32> into memref<2x3x20xf32>
+  return %0 : memref<2x3x20xf32>
+}
+
+// -----
+
+func @illegal_collapsing_reshape_mixed_tensor(%arg0 : tensor<?x?xf32>) -> tensor<?x4x5xf32>
+{
+  // expected-error @+1 {{expected dimension 1 of collapsed type to be static value of 5}}
+  %0 = linalg.tensor_reshape %arg0
+         [affine_map<(d0, d1, d2) -> (d0, d1)>,
+          affine_map<(d0, d1, d2) -> (d2)>] :
+       tensor<?x?xf32> into tensor<?x4x5xf32>
+  return %0 : tensor<?x4x5xf32>
+}
+
+// -----
+
+func @illegal_collapsing_reshape_mixed_tensor_2(%arg0 : tensor<?x?xf32>) -> tensor<?x4x5xf32>
+{
+  // expected-error @+1 {{expected dimension 1 of collapsed type to be static value of 20}}
+  %0 = linalg.tensor_reshape %arg0
+         [affine_map<(d0, d1, d2) -> (d0)>,
+          affine_map<(d0, d1, d2) -> (d1, d2)>] :
+       tensor<?x?xf32> into tensor<?x4x5xf32>
+  return %0 : tensor<?x4x5xf32>
+}
+
+// -----
+
+func @illegal_expanding_reshape_mixed_tensor(%arg0 : tensor<?x4x5xf32>) -> tensor<?x?xf32>
+{
+  // expected-error @+1 {{expected dimension 1 of collapsed type to be static value of 5}}
+  %0 = linalg.tensor_reshape %arg0
+         [affine_map<(d0, d1, d2) -> (d0, d1)>,
+          affine_map<(d0, d1, d2) -> (d2)>] :
+       tensor<?x4x5xf32> into tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+// -----
+
+func @illegal_expanding_reshape_mixed_tensor_2(%arg0 : tensor<?x4x5xf32>) -> tensor<?x?xf32>
+{
+  // expected-error @+1 {{expected dimension 1 of collapsed type to be static value of 20}}
+  %0 = linalg.tensor_reshape %arg0
+         [affine_map<(d0, d1, d2) -> (d0)>,
+          affine_map<(d0, d1, d2) -> (d1, d2)>] :
+       tensor<?x4x5xf32> into tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+// -----
+
+func @illegal_collapsing_reshape_mixed_memref(%arg0 : memref<?x?xf32>) -> memref<?x4x5xf32>
+{
+  // expected-error @+1 {{expected dimension 1 of collapsed type to be static value of 5}}
+  %0 = linalg.reshape %arg0
+         [affine_map<(d0, d1, d2) -> (d0, d1)>,
+          affine_map<(d0, d1, d2) -> (d2)>] :
+       memref<?x?xf32> into memref<?x4x5xf32>
+  return %0 : memref<?x4x5xf32>
+}
+
+// -----
+
+func @illegal_collapsing_reshape_mixed_memref_2(%arg0 : memref<?x?xf32>) -> memref<?x4x5xf32>
+{
+  // expected-error @+1 {{expected dimension 1 of collapsed type to be static value of 20}}
+  %0 = linalg.reshape %arg0
+         [affine_map<(d0, d1, d2) -> (d0)>,
+          affine_map<(d0, d1, d2) -> (d1, d2)>] :
+       memref<?x?xf32> into memref<?x4x5xf32>
+  return %0 : memref<?x4x5xf32>
+}
+
+// -----
+
+func @illegal_expanding_reshape_mixed_memref(%arg0 : memref<?x4x5xf32>) -> memref<?x?xf32>
+{
+  // expected-error @+1 {{expected dimension 1 of collapsed type to be static value of 5}}
+  %0 = linalg.reshape %arg0
+         [affine_map<(d0, d1, d2) -> (d0, d1)>,
+          affine_map<(d0, d1, d2) -> (d2)>] :
+       memref<?x4x5xf32> into memref<?x?xf32>
+  return %0 : memref<?x?xf32>
+}
+
+// -----
+
+func @illegal_expanding_reshape_mixed_memref_2(%arg0 : memref<?x4x5xf32>) -> memref<?x?xf32>
+{
+  // expected-error @+1 {{expected dimension 1 of collapsed type to be static value of 20}}
+  %0 = linalg.reshape %arg0
+         [affine_map<(d0, d1, d2) -> (d0)>,
+          affine_map<(d0, d1, d2) -> (d1, d2)>] :
+       memref<?x4x5xf32> into memref<?x?xf32>
+  return %0 : memref<?x?xf32>
+}

diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index c4eb8f8eac67..d0121b0c90c7 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s
+// RUN: mlir-opt -split-input-file %s | FileCheck %s
 
 // TODO: Re-enable LLVM lowering test after IndexedGenericOp is lowered.
 //
@@ -621,7 +621,7 @@ func @reshape_dynamic(%arg0: memref<?x?x?xf32>,
     memref<?x?x?xf32> into memref<?x?xf32>
   %r0 = linalg.reshape %0 [affine_map<(i, j, k) -> (i, j)>,
                            affine_map<(i, j, k) -> (k)>] :
-    memref<?x?xf32> into memref<?x?x?xf32>
+    memref<?x?xf32> into memref<?x4x?xf32>
   %1 = linalg.reshape %arg1 [affine_map<(i, j, k) -> (i, j)>,
                              affine_map<(i, j, k) -> (k)>] :
     memref<?x?x?xf32, offset : 0, strides : [?, ?, 1]> into
@@ -629,7 +629,7 @@ func @reshape_dynamic(%arg0: memref<?x?x?xf32>,
   %r1 = linalg.reshape %1 [affine_map<(i, j, k) -> (i, j)>,
                            affine_map<(i, j, k) -> (k)>] :
     memref<?x?xf32, offset : 0, strides : [?, 1]> into
-    memref<?x?x?xf32, offset : 0, strides : [?, ?, 1]>
+    memref<?x4x?xf32, offset : 0, strides : [?, ?, 1]>
   %2 = linalg.reshape %arg2 [affine_map<(i, j, k) -> (i, j)>,
                              affine_map<(i, j, k) -> (k)>] :
     memref<?x?x?xf32, offset : ?, strides : [?, ?, 1]> into
@@ -637,7 +637,7 @@ func @reshape_dynamic(%arg0: memref<?x?x?xf32>,
   %r2 = linalg.reshape %2 [affine_map<(i, j, k) -> (i, j)>,
                            affine_map<(i, j, k) -> (k)>] :
     memref<?x?xf32, offset : ?, strides : [?, 1]> into
-    memref<?x?x?xf32, offset : ?, strides : [?, ?, 1]>
+    memref<?x4x?xf32, offset : ?, strides : [?, ?, 1]>
   return
 }
 
@@ -648,15 +648,15 @@ func @reshape_dynamic(%arg0: memref<?x?x?xf32>,
 //       CHECK:   linalg.reshape {{.*}} [#[[$reshapeD01]], #[[$reshapeD2]]]
 //  CHECK-SAME:     memref<?x?x?xf32> into memref<?x?xf32>
 //       CHECK:   linalg.reshape {{.*}} [#[[$reshapeD01]], #[[$reshapeD2]]]
-//  CHECK-SAME:     memref<?x?xf32> into memref<?x?x?xf32>
+//  CHECK-SAME:     memref<?x?xf32> into memref<?x4x?xf32>
 //       CHECK:   linalg.reshape {{.*}} [#[[$reshapeD01]], #[[$reshapeD2]]]
 //  CHECK-SAME:     memref<?x?x?xf32, #[[$strided3DOFF0]]> into memref<?x?xf32, #[[$strided2DOFF0]]>
 //       CHECK:   linalg.reshape {{.*}} [#[[$reshapeD01]], #[[$reshapeD2]]]
-//  CHECK-SAME:     memref<?x?xf32, #[[$strided2DOFF0]]> into memref<?x?x?xf32, #[[$strided3DOFF0]]>
+//  CHECK-SAME:     memref<?x?xf32, #[[$strided2DOFF0]]> into memref<?x4x?xf32, #[[$strided3DOFF0]]>
 //       CHECK:   linalg.reshape {{.*}} [#[[$reshapeD01]], #[[$reshapeD2]]]
 //  CHECK-SAME:     memref<?x?x?xf32, #[[$strided3D]]> into memref<?x?xf32, #[[$strided2D]]>
 //       CHECK:   linalg.reshape {{.*}} [#[[$reshapeD01]], #[[$reshapeD2]]]
-//  CHECK-SAME:     memref<?x?xf32, #[[$strided2D]]> into memref<?x?x?xf32, #[[$strided3D]]>
+//  CHECK-SAME:     memref<?x?xf32, #[[$strided2D]]> into memref<?x4x?xf32, #[[$strided3D]]>
 
 func @named_ops(%a3: memref<?x?x?xf32>, %b3: memref<?x?x?xf32>, %c3: memref<?x?x?xf32>,
                 %ta3: tensor<?x?x?xf32>, %tb3: tensor<?x?x?xf32>, %tc3: tensor<?x?x?xf32>)
@@ -720,27 +720,36 @@ func @init_tensor(%arg0 : index, %arg1 : index)
 
 // -----
 
-func @init_tensor_err(%arg0 : index, %arg1 : index)
+func @legal_collapsing_reshape_dynamic_tensor
+  (%arg0: tensor<?x?x?x4x?xf32>) -> tensor<?x?x?xf32>
 {
-  // expected-error @+1 {{specified type 'tensor<4x?x?x5xf32>' does not match the inferred type 'tensor<4x5x?x?xf32>'}}
-  %1 = linalg.init_tensor [4, 5, %arg0, %arg1] : tensor<4x?x?x5xf32>
-  return
-}
-
-// -----
-
-func @init_tensor_err(%arg0 : index)
-{
-  // expected-error @+1 {{expected 4 sizes values}}
-  %1 = linalg.init_tensor [4, 5, %arg0] : tensor<4x?x?x5xf32>
-  return
+  %0 = linalg.tensor_reshape %arg0
+    [affine_map<(d0, d1, d2, d3, d4) -> (d0)>,
+     affine_map<(d0, d1, d2, d3, d4) -> (d1)>,
+     affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d4)>] :
+    tensor<?x?x?x4x?xf32> into tensor<?x?x?xf32>
+  return %0 : tensor<?x?x?xf32>
 }
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d1)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d4)>
+//     CHECK: func @legal_collapsing_reshape_dynamic_tensor
+//     CHECK:   linalg.tensor_reshape %{{.+}} [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
 
 // -----
 
-func @init_tensor_err(%arg0 : index)
+func @legal_collapsing_reshape_dynamic_memref
+  (%arg0: memref<?x?x?x4x?xf32>) -> memref<?x?x?xf32>
 {
-  // expected-error @+1 {{expected 2 dynamic sizes values}}
-  %1 = "linalg.init_tensor"(%arg0) {static_sizes = [4, -1, -1, 5]} : (index) -> tensor<4x?x?x5xf32>
-  return
-}
+  %0 = linalg.reshape %arg0
+    [affine_map<(d0, d1, d2, d3, d4) -> (d0)>,
+     affine_map<(d0, d1, d2, d3, d4) -> (d1)>,
+     affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d4)>] :
+    memref<?x?x?x4x?xf32> into memref<?x?x?xf32>
+  return %0 : memref<?x?x?xf32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d1)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d4)>
+//     CHECK: func @legal_collapsing_reshape_dynamic_memref
+//     CHECK:   linalg.reshape %{{.+}} [#[[MAP0]], #[[MAP1]], #[[MAP2]]]


        

