[Mlir-commits] [mlir] 5440d0a - [mlir][Linalg] Add folders and canonicalizers for linalg.reshape/linalg.tensor_reshape operations.

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue May 12 23:04:01 PDT 2020


Author: MaheshRavishankar
Date: 2020-05-12T23:03:26-07:00
New Revision: 5440d0a12d7f6f7f7689a7a733de7cc622605270

URL: https://github.com/llvm/llvm-project/commit/5440d0a12d7f6f7f7689a7a733de7cc622605270
DIFF: https://github.com/llvm/llvm-project/commit/5440d0a12d7f6f7f7689a7a733de7cc622605270.diff

LOG: [mlir][Linalg] Add folders and canonicalizers for
linalg.reshape/linalg.tensor_reshape operations.

Differential Revision: https://reviews.llvm.org/D79765

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/test/Dialect/Linalg/canonicalize.mlir
    mlir/test/Dialect/Linalg/llvm.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index 874bda002a59..1615957ff0c3 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -77,6 +77,10 @@ class Linalg_ReshapeLikeOp<string mnemonic> :
 
   code commonExtraClassDeclaration = [{
     static StringRef getReassociationAttrName() { return "reassociation"; }
+    SmallVector<AffineMap, 4> getReassociationMaps() {
+      return llvm::to_vector<4>(llvm::map_range(reassociation(),
+        [](Attribute a) { return a.cast<AffineMapAttr>().getValue(); }));
+    }
   }];
   let assemblyFormat = [{
     $src $reassociation attr-dict `:` type($src) `into` type(results)
@@ -137,6 +141,7 @@ def Linalg_ReshapeOp : Linalg_ReshapeLikeOp<"reshape">,
     MemRefType getResultType() { return result().getType().cast<MemRefType>(); }
   }];
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }
 
 def Linalg_TensorReshapeOp : Linalg_ReshapeLikeOp<"tensor_reshape">,
@@ -187,11 +192,9 @@ def Linalg_TensorReshapeOp : Linalg_ReshapeLikeOp<"tensor_reshape">,
     RankedTensorType getResultType() {
       return result().getType().cast<RankedTensorType>();
     }
-    SmallVector<AffineMap, 4> getReassociationMaps() {
-      return llvm::to_vector<4>(llvm::map_range(reassociation(),
-        [](Attribute a) { return a.cast<AffineMapAttr>().getValue(); }));
-    }
   }];
+  let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }
 
 def Linalg_SliceOp : Linalg_Op<"slice", [

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 4b1b8cde639e..47697f472d94 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -246,6 +246,108 @@ static LogicalResult verify(IndexedGenericOp op) { return verifyGenericOp(op); }
 // ReshapeOp
 //===----------------------------------------------------------------------===//
 
+/// Compose the reassociation maps of two reshape ops that both collapse or
+/// both expand dimensions. `mapsProducer` must be the maps whose domain is the
+/// highest-rank type and `mapsConsumer` the maps whose domain is the
+/// intermediate type. Returns a null ArrayAttr if they cannot be composed.
+///
+/// For example,
+///   mapsProducer = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>,
+///                   affine_map<(d0, d1, d2, d3, d4) -> (d2)>,
+///                   affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>]
+///   mapsConsumer = [affine_map<(d0, d1, d2) -> (d0, d1)>,
+///                   affine_map<(d0, d1, d2) -> (d2)>]
+///
+/// is folded into
+///
+///   result = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
+///             affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>]
+static ArrayAttr collapseReassociationMaps(ArrayRef<AffineMap> mapsProducer,
+                                           ArrayRef<AffineMap> mapsConsumer,
+                                           MLIRContext *context) {
+  if (mapsProducer.empty() || mapsConsumer.empty() ||
+      mapsProducer[0].getNumDims() < mapsConsumer[0].getNumDims() ||
+      mapsProducer.size() != mapsConsumer[0].getNumDims())
+    return nullptr;
+  unsigned numLhsDims = mapsProducer[0].getNumDims();
+  unsigned currDim = 0;
+  SmallVector<Attribute, 4> reassociationMaps;
+  for (AffineMap rhs : mapsConsumer) {
+    SmallVector<AffineExpr, 4> reassociations;
+    for (AffineExpr rhsExpr : rhs.getResults()) {
+      unsigned pos = rhsExpr.cast<AffineDimExpr>().getPosition();
+      for (unsigned i = 0, e = mapsProducer[pos].getNumResults(); i != e; ++i)
+        reassociations.push_back(getAffineDimExpr(currDim++, context));
+    }
+    reassociationMaps.push_back(AffineMapAttr::get(AffineMap::get(
+        numLhsDims, /*numSymbols =*/0, reassociations, context)));
+  }
+  // The composed maps must cover exactly the `numLhsDims` producer dims.
+  if (currDim != numLhsDims)
+    return nullptr;
+  return ArrayAttr::get(reassociationMaps, context);
+}
+
+namespace {
+/// Pattern to collapse producer/consumer reshape ops that are both collapsing
+/// dimensions or are both expanding dimensions.
+template <typename ReshapeOpTy>
+struct CollapseReshapeOps : public OpRewritePattern<ReshapeOpTy> {
+  using OpRewritePattern<ReshapeOpTy>::OpRewritePattern;
+  LogicalResult matchAndRewrite(ReshapeOpTy reshapeOp,
+                                PatternRewriter &rewriter) const override {
+    auto srcReshapeOp =
+        dyn_cast_or_null<ReshapeOpTy>(reshapeOp.src().getDefiningOp());
+    if (!srcReshapeOp)
+      return failure();
+
+    auto areReshapeOpsFoldable = [](ShapedType largerType,
+                                    ShapedType intermediateType,
+                                    ShapedType smallerType) -> bool {
+      return largerType.getRank() > intermediateType.getRank() &&
+             intermediateType.getRank() > smallerType.getRank() &&
+             smallerType.getRank() > 0;
+    };
+    ArrayAttr collapsedMaps;
+    // Both ops expand dims: the consumer's maps index the highest-rank type.
+    if (areReshapeOpsFoldable(reshapeOp.getResultType(), reshapeOp.getSrcType(),
+                              srcReshapeOp.getSrcType())) {
+      collapsedMaps = collapseReassociationMaps(
+          reshapeOp.getReassociationMaps(), srcReshapeOp.getReassociationMaps(),
+          rewriter.getContext());
+    } else if (areReshapeOpsFoldable(srcReshapeOp.getSrcType(),
+                                     reshapeOp.getSrcType(),
+                                     reshapeOp.getResultType())) {
+      // Both ops collapse dims: the producer's maps index the highest-rank type.
+      collapsedMaps = collapseReassociationMaps(
+          srcReshapeOp.getReassociationMaps(), reshapeOp.getReassociationMaps(),
+          rewriter.getContext());
+    }
+    // Bail out rather than build an op with a null reassociation attribute.
+    if (!collapsedMaps)
+      return failure();
+    rewriter.replaceOpWithNewOp<ReshapeOpTy>(
+        reshapeOp, reshapeOp.getResultType(), srcReshapeOp.src(),
+        collapsedMaps);
+    return success();
+  }
+};
+} // namespace
+
+/// Fold a producer/consumer pair of reshape ops when the source type of the
+/// producer is the same as the result type of the consumer. Equal types only
+/// imply equal shapes when they are static, hence the static-shape checks.
+template <typename ReshapeOpTy>
+static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp) {
+  ReshapeOpTy reshapeSrcOp =
+      dyn_cast_or_null<ReshapeOpTy>(reshapeOp.src().getDefiningOp());
+  if (reshapeSrcOp && reshapeSrcOp.getSrcType().hasStaticShape() &&
+      reshapeOp.getResultType().hasStaticShape() &&
+      reshapeSrcOp.getSrcType() == reshapeOp.getResultType())
+    return reshapeSrcOp.src();
+  return nullptr;
+}
+
 /// Return true if the reassociation specification is valid, false otherwise.
 /// When false, the `invalidIndex` integer pointer is optionally filled with the
 /// index of the offending reassociation map.
@@ -482,6 +584,11 @@ static LogicalResult verify(ReshapeOp op) {
   return success();
 }
 
+/// Canonicalize pairs of reshapes that both collapse or both expand dims.
+void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                            MLIRContext *context) {
+  results.insert<CollapseReshapeOps<ReshapeOp>>(context);
+}
 //===----------------------------------------------------------------------===//
 // TensorReshapeOp
 //===----------------------------------------------------------------------===//
@@ -551,6 +658,11 @@ static LogicalResult verify(TensorReshapeOp op) {
   return success();
 }
 
+/// Canonicalize pairs of tensor reshapes that both collapse or both expand.
+void TensorReshapeOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &results, MLIRContext *context) {
+  results.insert<CollapseReshapeOps<TensorReshapeOp>>(context);
+}
 //===----------------------------------------------------------------------===//
 // SliceOp
 //===----------------------------------------------------------------------===//
@@ -1010,13 +1122,18 @@ LogicalResult MatmulOp::fold(ArrayRef<Attribute>,
 OpFoldResult ReshapeOp::fold(ArrayRef<Attribute>) {
   if (succeeded(foldMemRefCast(*this)))
     return getResult();
-  return {};
+  return foldReshapeOp(*this); // Fold static-shape reshape round trips.
 }
 OpFoldResult SliceOp::fold(ArrayRef<Attribute>) {
   if (succeeded(foldMemRefCast(*this)))
     return getResult();
   return {};
 }
+OpFoldResult TensorReshapeOp::fold(ArrayRef<Attribute>) {
+  // A tensor operand is never defined by a memref_cast, so there is no cast
+  // to fold (unlike the memref ops); only reshape pairs can be folded.
+  return foldReshapeOp(*this);
+}
 OpFoldResult TransposeOp::fold(ArrayRef<Attribute>) {
   if (succeeded(foldMemRefCast(*this)))
     return getResult();

diff  --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index d96c81097f5f..00d0aaa89d4f 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -canonicalize | FileCheck %s
+// RUN: mlir-opt %s -canonicalize -split-input-file | FileCheck %s
 
 // CHECK-LABEL: func @memref_cast(
 func @memref_cast(%a: index, %b: index) -> memref<?x?xf32> {
@@ -18,3 +18,157 @@ func @memref_cast(%a: index, %b: index) -> memref<?x?xf32> {
   linalg.matmul(%3, %3, %3) : memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>
   return %4: memref<?x?xf32>
 }
+
+// -----
+// Two collapsing tensor_reshapes are combined into a single tensor_reshape.
+func @collapsing_tensor_reshapes(%arg0 : tensor<?x?x?x?x?xf32>) -> tensor<?x?xf32>
+{
+  %0 = linalg.tensor_reshape %arg0
+         [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>,
+          affine_map<(d0, d1, d2, d3, d4) -> (d2)>,
+          affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>] :
+       tensor<?x?x?x?x?xf32> into tensor<?x?x?xf32>
+  %1 = linalg.tensor_reshape %0
+         [affine_map<(d0, d1, d2) -> (d0, d1)>,
+          affine_map<(d0, d1, d2) -> (d2)>] :
+       tensor<?x?x?xf32> into tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+//   CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
+//   CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>
+// CHECK-LABEL: collapsing_tensor_reshapes
+//       CHECK:   linalg.tensor_reshape %{{.*}} [#[[MAP0]], #[[MAP1]]]
+//   CHECK-NOT:   linalg.tensor_reshape
+
+// -----
+// Two expanding tensor_reshapes are combined into a single tensor_reshape.
+func @expanding_tensor_reshapes(%arg0 : tensor<?x?xf32>) -> tensor<?x?x?x?x?xf32>
+{
+  %0 = linalg.tensor_reshape %arg0
+         [affine_map<(d0, d1, d2) -> (d0, d1)>,
+          affine_map<(d0, d1, d2) -> (d2)>] :
+       tensor<?x?xf32> into tensor<?x?x?xf32>
+  %1 = linalg.tensor_reshape %0
+         [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>,
+          affine_map<(d0, d1, d2, d3, d4) -> (d2)>,
+          affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>] :
+       tensor<?x?x?xf32> into tensor<?x?x?x?x?xf32>
+  return %1 : tensor<?x?x?x?x?xf32>
+}
+//   CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
+//   CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>
+// CHECK-LABEL: expanding_tensor_reshapes
+//       CHECK:   linalg.tensor_reshape %{{.*}} [#[[MAP0]], #[[MAP1]]]
+//   CHECK-NOT:   linalg.tensor_reshape
+
+// -----
+// Two collapsing reshapes on memrefs are combined into a single reshape.
+func @collapsing_memref_reshapes(%arg0 : memref<?x?x?x?x?xf32>) -> memref<?x?xf32>
+{
+  %0 = linalg.reshape %arg0
+         [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>,
+          affine_map<(d0, d1, d2, d3, d4) -> (d2)>,
+          affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>] :
+       memref<?x?x?x?x?xf32> into memref<?x?x?xf32>
+  %1 = linalg.reshape %0
+         [affine_map<(d0, d1, d2) -> (d0, d1)>,
+          affine_map<(d0, d1, d2) -> (d2)>] :
+       memref<?x?x?xf32> into memref<?x?xf32>
+  return %1 : memref<?x?xf32>
+}
+//   CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
+//   CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>
+// CHECK-LABEL: collapsing_memref_reshapes
+//       CHECK:   linalg.reshape %{{.*}} [#[[MAP0]], #[[MAP1]]]
+//   CHECK-NOT:   linalg.reshape
+
+// -----
+// Two expanding reshapes on memrefs are combined into a single reshape.
+func @expanding_memref_reshapes(%arg0 : memref<?x?xf32>) -> memref<?x?x?x?x?xf32>
+{
+  %0 = linalg.reshape %arg0
+         [affine_map<(d0, d1, d2) -> (d0, d1)>,
+          affine_map<(d0, d1, d2) -> (d2)>] :
+       memref<?x?xf32> into memref<?x?x?xf32>
+  %1 = linalg.reshape %0
+         [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>,
+          affine_map<(d0, d1, d2, d3, d4) -> (d2)>,
+          affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>] :
+       memref<?x?x?xf32> into memref<?x?x?x?x?xf32>
+  return %1 : memref<?x?x?x?x?xf32>
+}
+//   CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
+//   CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>
+// CHECK-LABEL: expanding_memref_reshapes
+//       CHECK:   linalg.reshape %{{.*}} [#[[MAP0]], #[[MAP1]]]
+//   CHECK-NOT:   linalg.reshape
+
+// -----
+// A static-shape expand/collapse round trip on tensors folds away entirely.
+func @fold_tensor_reshape(%arg0 : tensor<12x4xf32>) -> tensor<12x4xf32>
+{
+  %0 = linalg.tensor_reshape %arg0
+         [affine_map<(d0, d1, d2) -> (d0, d1)>,
+          affine_map<(d0, d1, d2) -> (d2)>] :
+       tensor<12x4xf32> into tensor<3x4x4xf32>
+  %1 = linalg.tensor_reshape %0
+         [affine_map<(d0, d1, d2) -> (d0, d1)>,
+          affine_map<(d0, d1, d2) -> (d2)>] :
+       tensor<3x4x4xf32> into tensor<12x4xf32>
+  return %1 : tensor<12x4xf32>
+}
+// CHECK-LABEL: @fold_tensor_reshape
+//   CHECK-NOT:   linalg.tensor_reshape
+
+// -----
+// With dynamic shapes the round trip is not provably an identity: no folding.
+func @no_fold_tensor_reshape(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32>
+{
+  %0 = linalg.tensor_reshape %arg0
+         [affine_map<(d0, d1, d2) -> (d0, d1)>,
+          affine_map<(d0, d1, d2) -> (d2)>] :
+       tensor<?x?xf32> into tensor<?x?x?xf32>
+  %1 = linalg.tensor_reshape %0
+         [affine_map<(d0, d1, d2) -> (d0, d1)>,
+          affine_map<(d0, d1, d2) -> (d2)>] :
+       tensor<?x?x?xf32> into tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+// CHECK-LABEL: @no_fold_tensor_reshape
+//       CHECK:   linalg.tensor_reshape
+//       CHECK:   linalg.tensor_reshape
+
+// -----
+// A static-shape expand/collapse round trip on memrefs folds away entirely.
+func @fold_memref_reshape(%arg0 : memref<12x4xf32>) -> memref<12x4xf32>
+{
+  %0 = linalg.reshape %arg0
+         [affine_map<(d0, d1, d2) -> (d0, d1)>,
+          affine_map<(d0, d1, d2) -> (d2)>] :
+       memref<12x4xf32> into memref<3x4x4xf32>
+  %1 = linalg.reshape %0
+         [affine_map<(d0, d1, d2) -> (d0, d1)>,
+          affine_map<(d0, d1, d2) -> (d2)>] :
+       memref<3x4x4xf32> into memref<12x4xf32>
+  return %1 : memref<12x4xf32>
+}
+// CHECK-LABEL: @fold_memref_reshape
+//   CHECK-NOT:   linalg.reshape
+
+// -----
+// Dynamic memref shapes: the round trip is not provably an identity, no fold.
+func @no_fold_memref_reshape(%arg0 : memref<?x?xf32>) -> memref<?x?xf32>
+{
+  %0 = linalg.reshape %arg0
+         [affine_map<(d0, d1, d2) -> (d0, d1)>,
+          affine_map<(d0, d1, d2) -> (d2)>] :
+       memref<?x?xf32> into memref<?x?x?xf32>
+  %1 = linalg.reshape %0
+         [affine_map<(d0, d1, d2) -> (d0, d1)>,
+          affine_map<(d0, d1, d2) -> (d2)>] :
+       memref<?x?x?xf32> into memref<?x?xf32>
+  return %1 : memref<?x?xf32>
+}
+// CHECK-LABEL: @no_fold_memref_reshape
+//       CHECK:   linalg.reshape
+//       CHECK:   linalg.reshape

diff  --git a/mlir/test/Dialect/Linalg/llvm.mlir b/mlir/test/Dialect/Linalg/llvm.mlir
index 84b3bcb66940..e158e70caec8 100644
--- a/mlir/test/Dialect/Linalg/llvm.mlir
+++ b/mlir/test/Dialect/Linalg/llvm.mlir
@@ -214,72 +214,76 @@ func @matmul_vec_indexed(%A: !matrix_type_A,
 // CHECK-SAME: !llvm<"<4 x float>*">, !llvm<"<4 x float>*">, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64
 // CHECK-SAME: !llvm<"[4 x <4 x float>]*">, !llvm<"[4 x <4 x float>]*">, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64
 
-func @reshape_static(%arg0: memref<3x4x5xf32>) {
-  // Reshapes that expand and collapse back a contiguous tensor with some 1's.
+func @reshape_static_expand(%arg0: memref<3x4x5xf32>) -> memref<1x3x4x1x5xf32> {
+  // Reshape that expands a contiguous memref by inserting unit dimensions.
   %0 = linalg.reshape %arg0 [affine_map<(i, j, k, l, m) -> (i, j)>,
                              affine_map<(i, j, k, l, m) -> (k)>,
                              affine_map<(i, j, k, l, m) -> (l, m)>] :
     memref<3x4x5xf32> into memref<1x3x4x1x5xf32>
-  %r0 = linalg.reshape %0 [affine_map<(i, j, k, l, m) -> (i, j)>,
-                           affine_map<(i, j, k, l, m) -> (k)>,
-                           affine_map<(i, j, k, l, m) -> (l, m)>] :
+  return %0 : memref<1x3x4x1x5xf32>
+}
+// CHECK-LABEL: func @reshape_static_expand
+//       CHECK:    llvm.mlir.undef : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+//       CHECK:    llvm.extractvalue %{{.*}}[0] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
+//       CHECK:    llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+//       CHECK:    llvm.extractvalue %{{.*}}[1] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
+//       CHECK:    llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+//       CHECK:    llvm.extractvalue %{{.*}}[2] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
+//       CHECK:    llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+//       CHECK:    llvm.mlir.constant(1 : index) : !llvm.i64
+//       CHECK:    llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+//       CHECK:    llvm.mlir.constant(3 : index) : !llvm.i64
+//       CHECK:    llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+//       CHECK:    llvm.mlir.constant(4 : index) : !llvm.i64
+//       CHECK:    llvm.insertvalue %{{.*}}, %{{.*}}[3, 2] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+//       CHECK:    llvm.mlir.constant(1 : index) : !llvm.i64
+//       CHECK:    llvm.insertvalue %{{.*}}, %{{.*}}[3, 3] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+//       CHECK:    llvm.mlir.constant(5 : index) : !llvm.i64
+//       CHECK:    llvm.insertvalue %{{.*}}, %{{.*}}[3, 4] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+//       CHECK:    llvm.mlir.constant(60 : index) : !llvm.i64
+//       CHECK:    llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+//       CHECK:    llvm.mlir.constant(20 : index) : !llvm.i64
+//       CHECK:    llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+//       CHECK:    llvm.mlir.constant(5 : index) : !llvm.i64
+//       CHECK:    llvm.insertvalue %{{.*}}, %{{.*}}[4, 2] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+//       CHECK:    llvm.mlir.constant(5 : index) : !llvm.i64
+//       CHECK:    llvm.insertvalue %{{.*}}, %{{.*}}[4, 3] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+//       CHECK:    llvm.mlir.constant(1 : index) : !llvm.i64
+//       CHECK:    llvm.insertvalue %{{.*}}, %{{.*}}[4, 4] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+// Reshape that collapses a contiguous memref by dropping unit dimensions.
+func @reshape_static_collapse(%arg0: memref<1x3x4x1x5xf32>) -> memref<3x4x5xf32> {
+  %0 = linalg.reshape %arg0 [affine_map<(i, j, k, l, m) -> (i, j)>,
+                             affine_map<(i, j, k, l, m) -> (k)>,
+                             affine_map<(i, j, k, l, m) -> (l, m)>] :
     memref<1x3x4x1x5xf32> into memref<3x4x5xf32>
-  return
+  return %0 : memref<3x4x5xf32>
 }
-// CHECK-LABEL: func @reshape_static(
-//       CHECK: llvm.mlir.undef : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
-//       CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
-//       CHECK: llvm.insertvalue {{.*}}[0] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
-//       CHECK: llvm.extractvalue {{.*}}[1] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
-//       CHECK: llvm.insertvalue {{.*}}[1] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
-//       CHECK: llvm.extractvalue {{.*}}[2] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
-//       CHECK: llvm.insertvalue {{.*}}[2] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
-//       CHECK: llvm.mlir.constant(1 : index) : !llvm.i64
-//       CHECK: llvm.insertvalue {{.*}}[3, 0] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
-//       CHECK: llvm.mlir.constant(3 : index) : !llvm.i64
-//       CHECK: llvm.insertvalue {{.*}}[3, 1] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
-//       CHECK: llvm.mlir.constant(4 : index) : !llvm.i64
-//       CHECK: llvm.insertvalue {{.*}}[3, 2] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
-//       CHECK: llvm.mlir.constant(1 : index) : !llvm.i64
-//       CHECK: llvm.insertvalue {{.*}}[3, 3] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
-//       CHECK: llvm.mlir.constant(5 : index) : !llvm.i64
-//       CHECK: llvm.insertvalue {{.*}}[3, 4] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
-//       CHECK: llvm.mlir.constant(60 : index) : !llvm.i64
-//       CHECK: llvm.insertvalue {{.*}}[4, 0] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
-//       CHECK: llvm.mlir.constant(20 : index) : !llvm.i64
-//       CHECK: llvm.insertvalue {{.*}}[4, 1] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
-//       CHECK: llvm.mlir.constant(5 : index) : !llvm.i64
-//       CHECK: llvm.insertvalue {{.*}}[4, 2] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
-//       CHECK: llvm.mlir.constant(5 : index) : !llvm.i64
-//       CHECK: llvm.insertvalue {{.*}}[4, 3] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
-//       CHECK: llvm.mlir.constant(1 : index) : !llvm.i64
-//       CHECK: llvm.insertvalue {{.*}}[4, 4] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
-//       CHECK: llvm.mlir.undef : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
-//       CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
-//       CHECK: llvm.insertvalue {{.*}}[0] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
-//       CHECK: llvm.extractvalue {{.*}}[1] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
-//       CHECK: llvm.insertvalue {{.*}}[1] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
-//       CHECK: llvm.extractvalue {{.*}}[2] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
-//       CHECK: llvm.insertvalue {{.*}}[2] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
-//       CHECK: llvm.mlir.constant(3 : index) : !llvm.i64
-//       CHECK: llvm.insertvalue {{.*}}[3, 0] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
-//       CHECK: llvm.mlir.constant(4 : index) : !llvm.i64
-//       CHECK: llvm.insertvalue {{.*}}[3, 1] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
-//       CHECK: llvm.mlir.constant(5 : index) : !llvm.i64
-//       CHECK: llvm.insertvalue {{.*}}[3, 2] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
-//       CHECK: llvm.mlir.constant(20 : index) : !llvm.i64
-//       CHECK: llvm.insertvalue {{.*}}[4, 0] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
-//       CHECK: llvm.mlir.constant(5 : index) : !llvm.i64
-//       CHECK: llvm.insertvalue {{.*}}[4, 1] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
-//       CHECK: llvm.mlir.constant(1 : index) : !llvm.i64
-//       CHECK: llvm.insertvalue {{.*}}[4, 2] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
+// CHECK-LABEL: func @reshape_static_collapse
+//       CHECK:    llvm.mlir.undef : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
+//       CHECK:    llvm.extractvalue %{{.*}}[0] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+//       CHECK:    llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
+//       CHECK:    llvm.extractvalue %{{.*}}[1] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+//       CHECK:    llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
+//       CHECK:    llvm.extractvalue %{{.*}}[2] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+//       CHECK:    llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
+//       CHECK:    llvm.mlir.constant(3 : index) : !llvm.i64
+//       CHECK:    llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
+//       CHECK:    llvm.mlir.constant(4 : index) : !llvm.i64
+//       CHECK:    llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
+//       CHECK:    llvm.mlir.constant(5 : index) : !llvm.i64
+//       CHECK:    llvm.insertvalue %{{.*}}, %{{.*}}[3, 2] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
+//       CHECK:    llvm.mlir.constant(20 : index) : !llvm.i64
+//       CHECK:    llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
+//       CHECK:    llvm.mlir.constant(5 : index) : !llvm.i64
+//       CHECK:    llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
+//       CHECK:    llvm.mlir.constant(1 : index) : !llvm.i64
+//       CHECK:    llvm.insertvalue %{{.*}}, %{{.*}}[4, 2] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
 
-func @reshape_zero_dim(%arg0 : memref<1x1xf32>) {
+func @reshape_collapse_zero_dim(%arg0 : memref<1x1xf32>) -> memref<f32> {
   %0 = linalg.reshape %arg0 [] : memref<1x1xf32> into memref<f32>
-  %1 = linalg.reshape %0 [] : memref<f32> into memref<1x1xf32>
-  return
+  return %0 : memref<f32>
 }
-// CHECK-LABEL: func @reshape_zero_dim
+// CHECK-LABEL: func @reshape_collapse_zero_dim
 //       CHECK:   llvm.mlir.undef : !llvm<"{ float*, float*, i64 }">
 //       CHECK:   llvm.extractvalue %{{.*}}[0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
 //       CHECK:   llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm<"{ float*, float*, i64 }">
@@ -287,6 +291,12 @@ func @reshape_zero_dim(%arg0 : memref<1x1xf32>) {
 //       CHECK:   llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm<"{ float*, float*, i64 }">
 //       CHECK:   llvm.extractvalue %{{.*}}[2] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
 //       CHECK:   llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm<"{ float*, float*, i64 }">
+// Reshape that expands a zero-rank memref into a memref with unit dims.
+func @reshape_expand_zero_dim(%arg0 : memref<f32>) -> memref<1x1xf32> {
+  %0 = linalg.reshape %arg0 [] : memref<f32> into memref<1x1xf32>
+  return %0 : memref<1x1xf32>
+}
+// CHECK-LABEL: func @reshape_expand_zero_dim
 //       CHECK:   llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
 //       CHECK:   llvm.extractvalue %{{.*}}[0] : !llvm<"{ float*, float*, i64 }">
 //       CHECK:   llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">


        


More information about the Mlir-commits mailing list