[Mlir-commits] [mlir] 5440d0a - [mlir][Linalg] Add folders and canonicalizers for
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue May 12 23:04:01 PDT 2020
Author: MaheshRavishankar
Date: 2020-05-12T23:03:26-07:00
New Revision: 5440d0a12d7f6f7f7689a7a733de7cc622605270
URL: https://github.com/llvm/llvm-project/commit/5440d0a12d7f6f7f7689a7a733de7cc622605270
DIFF: https://github.com/llvm/llvm-project/commit/5440d0a12d7f6f7f7689a7a733de7cc622605270.diff
LOG: [mlir][Linalg] Add folders and canonicalizers for
linalg.reshape/linalg.tensor_reshape operations.
Differential Revision: https://reviews.llvm.org/D79765
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/test/Dialect/Linalg/canonicalize.mlir
mlir/test/Dialect/Linalg/llvm.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index 874bda002a59..1615957ff0c3 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -77,6 +77,10 @@ class Linalg_ReshapeLikeOp<string mnemonic> :
code commonExtraClassDeclaration = [{
static StringRef getReassociationAttrName() { return "reassociation"; }
+    /// Returns the entries of the `reassociation` attribute, in order, as a
+    /// list of affine maps.
+    SmallVector<AffineMap, 4> getReassociationMaps() {
+      SmallVector<AffineMap, 4> maps;
+      for (Attribute attr : reassociation())
+        maps.push_back(attr.cast<AffineMapAttr>().getValue());
+      return maps;
+    }
}];
let assemblyFormat = [{
$src $reassociation attr-dict `:` type($src) `into` type(results)
@@ -137,6 +141,7 @@ def Linalg_ReshapeOp : Linalg_ReshapeLikeOp<"reshape">,
MemRefType getResultType() { return result().getType().cast<MemRefType>(); }
}];
let hasFolder = 1;
+ let hasCanonicalizer = 1;
}
def Linalg_TensorReshapeOp : Linalg_ReshapeLikeOp<"tensor_reshape">,
@@ -187,11 +192,9 @@ def Linalg_TensorReshapeOp : Linalg_ReshapeLikeOp<"tensor_reshape">,
RankedTensorType getResultType() {
return result().getType().cast<RankedTensorType>();
}
- SmallVector<AffineMap, 4> getReassociationMaps() {
- return llvm::to_vector<4>(llvm::map_range(reassociation(),
- [](Attribute a) { return a.cast<AffineMapAttr>().getValue(); }));
- }
}];
+ let hasFolder = 1;
+ let hasCanonicalizer = 1;
}
def Linalg_SliceOp : Linalg_Op<"slice", [
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 4b1b8cde639e..47697f472d94 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -246,6 +246,108 @@ static LogicalResult verify(IndexedGenericOp op) { return verifyGenericOp(op); }
// ReshapeOp
//===----------------------------------------------------------------------===//
+/// Composes the reassociation maps of a producer/consumer pair of reshape ops
+/// into the reassociation maps of a single equivalent reshape. Only valid when
+/// both reshapes collapse dimensions or both reshapes expand dimensions.
+///
+/// `mapsProducer` are the maps whose domain is the highest-rank shape of the
+/// chain (their dimension count becomes the domain of every result map), and
+/// each dimension of `mapsConsumer` selects one entry of `mapsProducer`. Every
+/// consumer map is replaced by a map over consecutive dimensions, one for each
+/// result of the producer maps it references. This assumes the producer groups
+/// are contiguous and ordered, as valid reassociations are.
+///
+/// Returns a null attribute if either list is empty, or if the consumer's
+/// dimension count does not match the number of producer maps, or if the
+/// consumer has more dimensions than the producer.
+///
+/// For example,
+///   mapsProducer = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>,
+///                   affine_map<(d0, d1, d2, d3, d4) -> (d2)>,
+///                   affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>]
+///   mapsConsumer = [affine_map<(d0, d1, d2) -> (d0, d1)>,
+///                   affine_map<(d0, d1, d2) -> (d2)>]
+///
+/// is folded into
+///
+///   result = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
+///             affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>]
+static ArrayAttr collapseReassociationMaps(ArrayRef<AffineMap> mapsProducer,
+                                           ArrayRef<AffineMap> mapsConsumer,
+                                           MLIRContext *context) {
+  if (mapsProducer.empty() || mapsConsumer.empty())
+    return nullptr;
+  unsigned numLhsDims = mapsProducer[0].getNumDims();
+  unsigned numConsumerDims = mapsConsumer[0].getNumDims();
+  if (numLhsDims < numConsumerDims || mapsProducer.size() != numConsumerDims)
+    return nullptr;
+
+  unsigned nextDim = 0;
+  SmallVector<Attribute, 4> composedMaps;
+  composedMaps.reserve(mapsConsumer.size());
+  for (AffineMap consumerMap : mapsConsumer) {
+    // Each consumer result names one producer map; the composed map takes as
+    // many consecutive dimensions as that producer map has results.
+    SmallVector<AffineExpr, 4> dimExprs;
+    for (AffineExpr consumerExpr : consumerMap.getResults()) {
+      unsigned producerIdx = consumerExpr.cast<AffineDimExpr>().getPosition();
+      unsigned groupSize = mapsProducer[producerIdx].getNumResults();
+      for (unsigned k = 0; k < groupSize; ++k)
+        dimExprs.push_back(getAffineDimExpr(nextDim++, context));
+    }
+    composedMaps.push_back(AffineMapAttr::get(AffineMap::get(
+        numLhsDims, /*symbolCount=*/0, dimExprs, context)));
+  }
+  return ArrayAttr::get(composedMaps, context);
+}
+
+namespace {
+/// Pattern to merge a producer/consumer pair of reshape ops into a single
+/// reshape when both ops collapse dimensions or both ops expand dimensions.
+///
+/// For example, a tensor<?x?x?x?x?xf32> -> tensor<?x?x?xf32> reshape feeding a
+/// tensor<?x?x?xf32> -> tensor<?x?xf32> reshape becomes one
+/// tensor<?x?x?x?x?xf32> -> tensor<?x?xf32> reshape.
+template <typename ReshapeOpTy>
+struct CollapseReshapeOps : public OpRewritePattern<ReshapeOpTy> {
+  using OpRewritePattern<ReshapeOpTy>::OpRewritePattern;
+  LogicalResult matchAndRewrite(ReshapeOpTy reshapeOp,
+                                PatternRewriter &rewriter) const override {
+    auto srcReshapeOp =
+        dyn_cast_or_null<ReshapeOpTy>(reshapeOp.src().getDefiningOp());
+    if (!srcReshapeOp)
+      return failure();
+
+    // Ranks must strictly decrease from `largerType` to `smallerType`, and the
+    // smallest rank must be non-zero (rank-0 reshapes have no reassociation).
+    auto areReshapeOpsFoldable = [](ShapedType largerType,
+                                    ShapedType intermediateType,
+                                    ShapedType smallerType) -> bool {
+      return largerType.getRank() > intermediateType.getRank() &&
+             intermediateType.getRank() > smallerType.getRank() &&
+             smallerType.getRank() > 0;
+    };
+
+    // The first argument of collapseReassociationMaps is always the list of
+    // maps defined over the highest-rank shape in the chain.
+    ArrayAttr reassociation;
+    if (areReshapeOpsFoldable(reshapeOp.getResultType(), reshapeOp.getSrcType(),
+                              srcReshapeOp.getSrcType())) {
+      // Producer and consumer both expand dimensions.
+      reassociation = collapseReassociationMaps(
+          reshapeOp.getReassociationMaps(), srcReshapeOp.getReassociationMaps(),
+          rewriter.getContext());
+    } else if (areReshapeOpsFoldable(srcReshapeOp.getSrcType(),
+                                     reshapeOp.getSrcType(),
+                                     reshapeOp.getResultType())) {
+      // Producer and consumer both collapse dimensions.
+      reassociation = collapseReassociationMaps(
+          srcReshapeOp.getReassociationMaps(), reshapeOp.getReassociationMaps(),
+          rewriter.getContext());
+    } else {
+      return failure();
+    }
+
+    // collapseReassociationMaps returns a null attribute when the maps cannot
+    // be composed; bail out instead of building an op with a null
+    // reassociation attribute.
+    if (!reassociation)
+      return failure();
+
+    rewriter.replaceOpWithNewOp<ReshapeOpTy>(reshapeOp,
+                                             reshapeOp.getResultType(),
+                                             srcReshapeOp.src(), reassociation);
+    return success();
+  }
+};
+} // namespace
+
+/// Folds `reshapeOp` when its operand is produced by another reshape of the
+/// same kind, and that producer's operand type equals `reshapeOp`'s result
+/// type. In that case the producer's operand is returned as the fold result.
+/// Returns null when no fold applies.
+template <typename ReshapeOpTy>
+static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp) {
+  // Fold producer-consumer reshape ops where the operand type of the producer
+  // is the same as the result type of the consumer. This can only be verified
+  // when the shapes in question are static.
+  ReshapeOpTy reshapeSrcOp =
+      dyn_cast_or_null<ReshapeOpTy>(reshapeOp.src().getDefiningOp());
+  if (reshapeSrcOp && reshapeSrcOp.getSrcType().hasStaticShape() &&
+      reshapeOp.getResultType().hasStaticShape() &&
+      reshapeSrcOp.getSrcType() == reshapeOp.getResultType())
+    return reshapeSrcOp.src();
+  return nullptr;
+}
+
/// Return true if the reassociation specification is valid, false otherwise.
/// When false, the `invalidIndex` integer pointer is optionally filled with the
/// index of the offending reassociation map.
@@ -482,6 +584,11 @@ static LogicalResult verify(ReshapeOp op) {
return success();
}
+/// Registers the pattern that merges two back-to-back collapsing (or two
+/// expanding) linalg.reshape ops into one.
+void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                            MLIRContext *context) {
+  using ReshapeChainCollapser = CollapseReshapeOps<ReshapeOp>;
+  results.insert<ReshapeChainCollapser>(context);
+}
+
//===----------------------------------------------------------------------===//
// TensorReshapeOp
//===----------------------------------------------------------------------===//
@@ -551,6 +658,11 @@ static LogicalResult verify(TensorReshapeOp op) {
return success();
}
+/// Registers the pattern that merges two back-to-back collapsing (or two
+/// expanding) linalg.tensor_reshape ops into one.
+void TensorReshapeOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &results, MLIRContext *context) {
+  using TensorReshapeChainCollapser = CollapseReshapeOps<TensorReshapeOp>;
+  results.insert<TensorReshapeChainCollapser>(context);
+}
+
//===----------------------------------------------------------------------===//
// SliceOp
//===----------------------------------------------------------------------===//
@@ -1010,13 +1122,18 @@ LogicalResult MatmulOp::fold(ArrayRef<Attribute>,
+// Try foldMemRefCast first; on success this op's own result is returned.
+// Otherwise try foldReshapeOp, which cancels a producer reshape whose operand
+// has the same static type as this op's result.
OpFoldResult ReshapeOp::fold(ArrayRef<Attribute>) {
if (succeeded(foldMemRefCast(*this)))
return getResult();
- return {};
+ return foldReshapeOp(*this);
}
+// Only foldMemRefCast is attempted for slices; on success this op's own result
+// is returned, otherwise no fold is reported.
OpFoldResult SliceOp::fold(ArrayRef<Attribute>) {
if (succeeded(foldMemRefCast(*this)))
return getResult();
return {};
}
+/// Folds a linalg.tensor_reshape whose operand comes from another
+/// tensor_reshape that started from this op's (static) result type, yielding
+/// the producer's operand. Returns null when no fold applies.
+OpFoldResult TensorReshapeOp::fold(ArrayRef<Attribute>) {
+  // tensor_reshape consumes a tensor, so the memref_cast folding used by the
+  // memref ops (foldMemRefCast) can never apply here; only reshape folding is
+  // attempted.
+  return foldReshapeOp(*this);
+}
OpFoldResult TransposeOp::fold(ArrayRef<Attribute>) {
if (succeeded(foldMemRefCast(*this)))
return getResult();
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index d96c81097f5f..00d0aaa89d4f 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -canonicalize | FileCheck %s
+// RUN: mlir-opt %s -canonicalize -split-input-file | FileCheck %s
// CHECK-LABEL: func @memref_cast(
func @memref_cast(%a: index, %b: index) -> memref<?x?xf32> {
@@ -18,3 +18,157 @@ func @memref_cast(%a: index, %b: index) -> memref<?x?xf32> {
linalg.matmul(%3, %3, %3) : memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>
return %4: memref<?x?xf32>
}
+
+// -----
+
+// Two collapsing tensor_reshapes (rank 5 -> 3 -> 2) are merged by the
+// canonicalizer into a single tensor_reshape with composed maps.
+func @collapsing_tensor_reshapes(%arg0 : tensor<?x?x?x?x?xf32>) -> tensor<?x?xf32>
+{
+ %0 = linalg.tensor_reshape %arg0
+ [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>,
+ affine_map<(d0, d1, d2, d3, d4) -> (d2)>,
+ affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>] :
+ tensor<?x?x?x?x?xf32> into tensor<?x?x?xf32>
+ %1 = linalg.tensor_reshape %0
+ [affine_map<(d0, d1, d2) -> (d0, d1)>,
+ affine_map<(d0, d1, d2) -> (d2)>] :
+ tensor<?x?x?xf32> into tensor<?x?xf32>
+ return %1 : tensor<?x?xf32>
+}
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>
+// CHECK-LABEL: collapsing_tensor_reshapes
+// CHECK: linalg.tensor_reshape %{{.*}} [#[[MAP0]], #[[MAP1]]]
+// CHECK-NOT: linalg.tensor_reshape
+
+// -----
+
+// Two expanding tensor_reshapes (rank 2 -> 3 -> 5) are merged by the
+// canonicalizer into a single tensor_reshape with composed maps.
+func @expanding_tensor_reshapes(%arg0 : tensor<?x?xf32>) -> tensor<?x?x?x?x?xf32>
+{
+ %0 = linalg.tensor_reshape %arg0
+ [affine_map<(d0, d1, d2) -> (d0, d1)>,
+ affine_map<(d0, d1, d2) -> (d2)>] :
+ tensor<?x?xf32> into tensor<?x?x?xf32>
+ %1 = linalg.tensor_reshape %0
+ [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>,
+ affine_map<(d0, d1, d2, d3, d4) -> (d2)>,
+ affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>] :
+ tensor<?x?x?xf32> into tensor<?x?x?x?x?xf32>
+ return %1 : tensor<?x?x?x?x?xf32>
+}
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>
+// CHECK-LABEL: expanding_tensor_reshapes
+// CHECK: linalg.tensor_reshape %{{.*}} [#[[MAP0]], #[[MAP1]]]
+// CHECK-NOT: linalg.tensor_reshape
+
+// -----
+
+// Two collapsing memref reshapes (rank 5 -> 3 -> 2) are merged by the
+// canonicalizer into a single linalg.reshape with composed maps.
+func @collapsing_memref_reshapes(%arg0 : memref<?x?x?x?x?xf32>) -> memref<?x?xf32>
+{
+ %0 = linalg.reshape %arg0
+ [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>,
+ affine_map<(d0, d1, d2, d3, d4) -> (d2)>,
+ affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>] :
+ memref<?x?x?x?x?xf32> into memref<?x?x?xf32>
+ %1 = linalg.reshape %0
+ [affine_map<(d0, d1, d2) -> (d0, d1)>,
+ affine_map<(d0, d1, d2) -> (d2)>] :
+ memref<?x?x?xf32> into memref<?x?xf32>
+ return %1 : memref<?x?xf32>
+}
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>
+// CHECK-LABEL: collapsing_memref_reshapes
+// CHECK: linalg.reshape %{{.*}} [#[[MAP0]], #[[MAP1]]]
+// CHECK-NOT: linalg.reshape
+
+// -----
+
+// Two expanding memref reshapes (rank 2 -> 3 -> 5) are merged by the
+// canonicalizer into a single linalg.reshape with composed maps.
+func @expanding_memref_reshapes(%arg0 : memref<?x?xf32>) -> memref<?x?x?x?x?xf32>
+{
+ %0 = linalg.reshape %arg0
+ [affine_map<(d0, d1, d2) -> (d0, d1)>,
+ affine_map<(d0, d1, d2) -> (d2)>] :
+ memref<?x?xf32> into memref<?x?x?xf32>
+ %1 = linalg.reshape %0
+ [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>,
+ affine_map<(d0, d1, d2, d3, d4) -> (d2)>,
+ affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>] :
+ memref<?x?x?xf32> into memref<?x?x?x?x?xf32>
+ return %1 : memref<?x?x?x?x?xf32>
+}
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>
+// CHECK-LABEL: expanding_memref_reshapes
+// CHECK: linalg.reshape %{{.*}} [#[[MAP0]], #[[MAP1]]]
+// CHECK-NOT: linalg.reshape
+
+// -----
+
+// An expand (12x4 -> 3x4x4) followed by a collapse back to the same static
+// type folds away entirely, leaving no tensor_reshape.
+func @fold_tensor_reshape(%arg0 : tensor<12x4xf32>) -> tensor<12x4xf32>
+{
+ %0 = linalg.tensor_reshape %arg0
+ [affine_map<(d0, d1, d2) -> (d0, d1)>,
+ affine_map<(d0, d1, d2) -> (d2)>] :
+ tensor<12x4xf32> into tensor<3x4x4xf32>
+ %1 = linalg.tensor_reshape %0
+ [affine_map<(d0, d1, d2) -> (d0, d1)>,
+ affine_map<(d0, d1, d2) -> (d2)>] :
+ tensor<3x4x4xf32> into tensor<12x4xf32>
+ return %1 : tensor<12x4xf32>
+}
+// CHECK-LABEL: @fold_tensor_reshape
+// CHECK-NOT: linalg.tensor_reshape
+
+// -----
+
+// The same expand/collapse round trip on dynamic shapes is not folded, since
+// the types can only be compared when they are static; both ops remain.
+func @no_fold_tensor_reshape(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32>
+{
+ %0 = linalg.tensor_reshape %arg0
+ [affine_map<(d0, d1, d2) -> (d0, d1)>,
+ affine_map<(d0, d1, d2) -> (d2)>] :
+ tensor<?x?xf32> into tensor<?x?x?xf32>
+ %1 = linalg.tensor_reshape %0
+ [affine_map<(d0, d1, d2) -> (d0, d1)>,
+ affine_map<(d0, d1, d2) -> (d2)>] :
+ tensor<?x?x?xf32> into tensor<?x?xf32>
+ return %1 : tensor<?x?xf32>
+}
+// CHECK-LABEL: @no_fold_tensor_reshape
+// CHECK: linalg.tensor_reshape
+// CHECK: linalg.tensor_reshape
+
+// -----
+
+// An expand (12x4 -> 3x4x4) followed by a collapse back to the same static
+// memref type folds away entirely, leaving no linalg.reshape.
+func @fold_memref_reshape(%arg0 : memref<12x4xf32>) -> memref<12x4xf32>
+{
+ %0 = linalg.reshape %arg0
+ [affine_map<(d0, d1, d2) -> (d0, d1)>,
+ affine_map<(d0, d1, d2) -> (d2)>] :
+ memref<12x4xf32> into memref<3x4x4xf32>
+ %1 = linalg.reshape %0
+ [affine_map<(d0, d1, d2) -> (d0, d1)>,
+ affine_map<(d0, d1, d2) -> (d2)>] :
+ memref<3x4x4xf32> into memref<12x4xf32>
+ return %1 : memref<12x4xf32>
+}
+// CHECK-LABEL: @fold_memref_reshape
+// CHECK-NOT: linalg.reshape
+
+// -----
+
+// The same expand/collapse round trip on dynamically shaped memrefs is not
+// folded, since the types can only be compared when static; both ops remain.
+func @no_fold_memref_reshape(%arg0 : memref<?x?xf32>) -> memref<?x?xf32>
+{
+ %0 = linalg.reshape %arg0
+ [affine_map<(d0, d1, d2) -> (d0, d1)>,
+ affine_map<(d0, d1, d2) -> (d2)>] :
+ memref<?x?xf32> into memref<?x?x?xf32>
+ %1 = linalg.reshape %0
+ [affine_map<(d0, d1, d2) -> (d0, d1)>,
+ affine_map<(d0, d1, d2) -> (d2)>] :
+ memref<?x?x?xf32> into memref<?x?xf32>
+ return %1 : memref<?x?xf32>
+}
+// CHECK-LABEL: @no_fold_memref_reshape
+// CHECK: linalg.reshape
+// CHECK: linalg.reshape
diff --git a/mlir/test/Dialect/Linalg/llvm.mlir b/mlir/test/Dialect/Linalg/llvm.mlir
index 84b3bcb66940..e158e70caec8 100644
--- a/mlir/test/Dialect/Linalg/llvm.mlir
+++ b/mlir/test/Dialect/Linalg/llvm.mlir
@@ -214,72 +214,76 @@ func @matmul_vec_indexed(%A: !matrix_type_A,
// CHECK-SAME: !llvm<"<4 x float>*">, !llvm<"<4 x float>*">, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64
// CHECK-SAME: !llvm<"[4 x <4 x float>]*">, !llvm<"[4 x <4 x float>]*">, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64
-func @reshape_static(%arg0: memref<3x4x5xf32>) {
- // Reshapes that expand and collapse back a contiguous tensor with some 1's.
+// Expansion and collapse are tested in separate functions; each function
+// returns its reshape result so that the reshape op has a use.
+func @reshape_static_expand(%arg0: memref<3x4x5xf32>) -> memref<1x3x4x1x5xf32> {
+ // Reshape that expands a contiguous memref, inserting unit (size-1) dims.
%0 = linalg.reshape %arg0 [affine_map<(i, j, k, l, m) -> (i, j)>,
affine_map<(i, j, k, l, m) -> (k)>,
affine_map<(i, j, k, l, m) -> (l, m)>] :
memref<3x4x5xf32> into memref<1x3x4x1x5xf32>
- %r0 = linalg.reshape %0 [affine_map<(i, j, k, l, m) -> (i, j)>,
- affine_map<(i, j, k, l, m) -> (k)>,
- affine_map<(i, j, k, l, m) -> (l, m)>] :
+ return %0 : memref<1x3x4x1x5xf32>
+}
+// CHECK-LABEL: func @reshape_static_expand
+// CHECK: llvm.mlir.undef : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+// CHECK: llvm.extractvalue %{{.*}}[0] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+// CHECK: llvm.extractvalue %{{.*}}[1] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+// CHECK: llvm.extractvalue %{{.*}}[2] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+// CHECK: llvm.mlir.constant(1 : index) : !llvm.i64
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+// CHECK: llvm.mlir.constant(3 : index) : !llvm.i64
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+// CHECK: llvm.mlir.constant(4 : index) : !llvm.i64
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 2] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+// CHECK: llvm.mlir.constant(1 : index) : !llvm.i64
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 3] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+// CHECK: llvm.mlir.constant(5 : index) : !llvm.i64
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 4] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+// CHECK: llvm.mlir.constant(60 : index) : !llvm.i64
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+// CHECK: llvm.mlir.constant(20 : index) : !llvm.i64
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+// CHECK: llvm.mlir.constant(5 : index) : !llvm.i64
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 2] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+// CHECK: llvm.mlir.constant(5 : index) : !llvm.i64
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 3] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+// CHECK: llvm.mlir.constant(1 : index) : !llvm.i64
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 4] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+
+func @reshape_static_collapse(%arg0: memref<1x3x4x1x5xf32>) -> memref<3x4x5xf32> {
+ // Reshape that collapses the unit dims of a contiguous memref back to rank 3.
+ %0 = linalg.reshape %arg0 [affine_map<(i, j, k, l, m) -> (i, j)>,
+ affine_map<(i, j, k, l, m) -> (k)>,
+ affine_map<(i, j, k, l, m) -> (l, m)>] :
memref<1x3x4x1x5xf32> into memref<3x4x5xf32>
- return
+ return %0 : memref<3x4x5xf32>
}
-// CHECK-LABEL: func @reshape_static(
-// CHECK: llvm.mlir.undef : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
-// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
-// CHECK: llvm.insertvalue {{.*}}[0] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
-// CHECK: llvm.extractvalue {{.*}}[1] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
-// CHECK: llvm.insertvalue {{.*}}[1] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
-// CHECK: llvm.extractvalue {{.*}}[2] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
-// CHECK: llvm.insertvalue {{.*}}[2] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
-// CHECK: llvm.mlir.constant(1 : index) : !llvm.i64
-// CHECK: llvm.insertvalue {{.*}}[3, 0] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
-// CHECK: llvm.mlir.constant(3 : index) : !llvm.i64
-// CHECK: llvm.insertvalue {{.*}}[3, 1] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
-// CHECK: llvm.mlir.constant(4 : index) : !llvm.i64
-// CHECK: llvm.insertvalue {{.*}}[3, 2] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
-// CHECK: llvm.mlir.constant(1 : index) : !llvm.i64
-// CHECK: llvm.insertvalue {{.*}}[3, 3] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
-// CHECK: llvm.mlir.constant(5 : index) : !llvm.i64
-// CHECK: llvm.insertvalue {{.*}}[3, 4] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
-// CHECK: llvm.mlir.constant(60 : index) : !llvm.i64
-// CHECK: llvm.insertvalue {{.*}}[4, 0] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
-// CHECK: llvm.mlir.constant(20 : index) : !llvm.i64
-// CHECK: llvm.insertvalue {{.*}}[4, 1] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
-// CHECK: llvm.mlir.constant(5 : index) : !llvm.i64
-// CHECK: llvm.insertvalue {{.*}}[4, 2] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
-// CHECK: llvm.mlir.constant(5 : index) : !llvm.i64
-// CHECK: llvm.insertvalue {{.*}}[4, 3] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
-// CHECK: llvm.mlir.constant(1 : index) : !llvm.i64
-// CHECK: llvm.insertvalue {{.*}}[4, 4] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
-// CHECK: llvm.mlir.undef : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
-// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
-// CHECK: llvm.insertvalue {{.*}}[0] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
-// CHECK: llvm.extractvalue {{.*}}[1] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
-// CHECK: llvm.insertvalue {{.*}}[1] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
-// CHECK: llvm.extractvalue {{.*}}[2] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
-// CHECK: llvm.insertvalue {{.*}}[2] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
-// CHECK: llvm.mlir.constant(3 : index) : !llvm.i64
-// CHECK: llvm.insertvalue {{.*}}[3, 0] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
-// CHECK: llvm.mlir.constant(4 : index) : !llvm.i64
-// CHECK: llvm.insertvalue {{.*}}[3, 1] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
-// CHECK: llvm.mlir.constant(5 : index) : !llvm.i64
-// CHECK: llvm.insertvalue {{.*}}[3, 2] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
-// CHECK: llvm.mlir.constant(20 : index) : !llvm.i64
-// CHECK: llvm.insertvalue {{.*}}[4, 0] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
-// CHECK: llvm.mlir.constant(5 : index) : !llvm.i64
-// CHECK: llvm.insertvalue {{.*}}[4, 1] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
-// CHECK: llvm.mlir.constant(1 : index) : !llvm.i64
-// CHECK: llvm.insertvalue {{.*}}[4, 2] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
+// CHECK-LABEL: func @reshape_static_collapse
+// CHECK: llvm.mlir.undef : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
+// CHECK: llvm.extractvalue %{{.*}}[0] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
+// CHECK: llvm.extractvalue %{{.*}}[1] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
+// CHECK: llvm.extractvalue %{{.*}}[2] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
+// CHECK: llvm.mlir.constant(3 : index) : !llvm.i64
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
+// CHECK: llvm.mlir.constant(4 : index) : !llvm.i64
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
+// CHECK: llvm.mlir.constant(5 : index) : !llvm.i64
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 2] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
+// CHECK: llvm.mlir.constant(20 : index) : !llvm.i64
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
+// CHECK: llvm.mlir.constant(5 : index) : !llvm.i64
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
+// CHECK: llvm.mlir.constant(1 : index) : !llvm.i64
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 2] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
-func @reshape_zero_dim(%arg0 : memref<1x1xf32>) {
+func @reshape_fold_zero_dim(%arg0 : memref<1x1xf32>) -> memref<f32> {
+ // Collapse a 1x1 memref into a rank-0 memref (empty reassociation list).
%0 = linalg.reshape %arg0 [] : memref<1x1xf32> into memref<f32>
- %1 = linalg.reshape %0 [] : memref<f32> into memref<1x1xf32>
- return
+ return %0 : memref<f32>
}
-// CHECK-LABEL: func @reshape_zero_dim
+// CHECK-LABEL: func @reshape_fold_zero_dim
// CHECK: llvm.mlir.undef : !llvm<"{ float*, float*, i64 }">
// CHECK: llvm.extractvalue %{{.*}}[0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm<"{ float*, float*, i64 }">
@@ -287,6 +291,12 @@ func @reshape_zero_dim(%arg0 : memref<1x1xf32>) {
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm<"{ float*, float*, i64 }">
// CHECK: llvm.extractvalue %{{.*}}[2] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm<"{ float*, float*, i64 }">
+
+func @reshape_expand_zero_dim(%arg0 : memref<f32>) -> memref<1x1xf32> {
+ // Expand a rank-0 memref into a 1x1 memref (empty reassociation list).
+ %0 = linalg.reshape %arg0 [] : memref<f32> into memref<1x1xf32>
+ return %0 : memref<1x1xf32>
+}
+// CHECK-LABEL: func @reshape_expand_zero_dim
// CHECK: llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: llvm.extractvalue %{{.*}}[0] : !llvm<"{ float*, float*, i64 }">
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
More information about the Mlir-commits
mailing list