[Mlir-commits] [mlir] 81264df - [mlir][Linalg] Add utility method to reshape ops to express output shape in terms of input shape.

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Feb 16 13:42:42 PST 2021


Author: MaheshRavishankar
Date: 2021-02-16T13:42:08-08:00
New Revision: 81264dfbe80df08668a325a61613b64243b99c01

URL: https://github.com/llvm/llvm-project/commit/81264dfbe80df08668a325a61613b64243b99c01
DIFF: https://github.com/llvm/llvm-project/commit/81264dfbe80df08668a325a61613b64243b99c01.diff

LOG: [mlir][Linalg] Add utility method to reshape ops to express output shape in terms of input shape.

Resolving the dims of the output of a tensor_reshape op in terms of its
input shape allows the op to be eliminated when it is used only for its
dims. The init_tensor -> tensor_reshape canonicalization can then be
simplified to use the dims of the output of the tensor_reshape; those
dims get canonicalized away later, making the tensor_reshape dead.
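
For a concrete picture, here is a sketch based on the dim_reshape_collapse
test added below (shown in its final canonicalized form, not an additional
change): a dim of the result of a collapsing linalg.tensor_reshape is
rewritten in terms of a dim of its source, after which the reshape has no
remaining uses.

  // Before: %0 is only used by the dim op.
  %c2 = constant 2 : index
  %0 = linalg.tensor_reshape %arg0
    [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1)>,
     affine_map<(d0, d1, d2, d3, d4, d5) -> (d2)>,
     affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>] :
     tensor<2x3x5x4x?x7xf32> into tensor<6x5x?xf32>
  %1 = dim %0, %c2 : tensor<6x5x?xf32>

  // After: output dim 2 collapses source dims 3, 4 and 5 (4 * ? * 7), so it
  // is the dynamic source dim scaled by 28; the tensor_reshape is now dead.
  %c4 = constant 4 : index
  %d = dim %arg0, %c4 : tensor<2x3x5x4x?x7xf32>
  %1 = affine.apply affine_map<()[s0] -> (s0 * 28)>()[%d]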

Differential Revision: https://reviews.llvm.org/D96635

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
    mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/test/Dialect/Linalg/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
index a98336382fe6..f22b00da01c9 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
@@ -107,6 +107,13 @@ SmallVector<AffineExpr, 4> concat(ArrayRef<AffineExpr> a,
 void getDimsOfType(Operation *op, StringRef iteratorTypeName,
                    SmallVectorImpl<AffineExpr> &res);
 
+/// For reshape operation, compute the shape of the output based on the result
+/// type and shape of the input.
+SmallVector<Value, 4>
+getReshapeOutputShapeFromInputShape(OpBuilder &b, Location loc, Value src,
+                                    ArrayRef<int64_t> dstStaticShape,
+                                    ArrayRef<AffineMap> reassociation);
+
 namespace detail {
 LogicalResult verifyStructuredOpInterface(Operation *op);
 } // namespace detail

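As an illustration of what the declared utility materializes (a sketch
mirroring the expansion tests added below, shown after canonicalization; it
is not part of the patch), for an expanding reshape each static result size
becomes a constant, and a dynamic result size is the corresponding source
size divided by the product of the remaining static sizes in its
reassociation group:

  // %src : tensor<6x5x?xf32> expanded into tensor<2x3x5x4x?x7xf32> with
  // reassociation [(d0, d1), (d2), (d3, d4, d5)].
  %c2 = constant 2 : index
  %d2 = dim %src, %c2 : tensor<6x5x?xf32>
  // Result dim 4 is the only dynamic size in its group (4, ?, 7), so it is
  // the source size floordiv (4 * 7) = floordiv 28.
  %s4 = affine.apply affine_map<()[s0] -> (s0 floordiv 28)>()[%d2]
  // The remaining result sizes (2, 3, 5, 4, 7) are emitted as index constants.
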
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index 7212700d641e..dc99e217aeb4 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -342,10 +342,15 @@ class Linalg_ReshapeLikeOp<string mnemonic, list<OpTrait> traits = []> :
     SmallVector<ReassociationExprs, 4> getReassociationExprs() {
       return
         llvm::to_vector<4>(llvm::map_range(reassociation(),
-	  [](Attribute a) {
-	    return llvm::to_vector<2>(
-	      a.cast<AffineMapAttr>().getValue().getResults());
-	  }));
+          [](Attribute a) {
+            return llvm::to_vector<2>(
+              a.cast<AffineMapAttr>().getValue().getResults());
+          }));
+    }
+    SmallVector<Value, 4> getOutputShape(OpBuilder &b, Location loc) {
+      return getReshapeOutputShapeFromInputShape(
+          b, loc, src(), getResultType().getShape(),
+          getReassociationMaps());
     }
   }];
   let assemblyFormat = [{

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 0efcddfbe1c6..7c348672dc37 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -605,85 +605,6 @@ Type InitTensorOp::inferResultType(ArrayRef<int64_t> staticSizes,
   return RankedTensorType::get(staticSizes, elementType);
 }
 
-namespace {
-/// Change the type of the result of a `linalg.init_tensor` by making the result
-/// type statically sized along dimension that in the original operation where
-/// defined as dynamic, but the size was defined using a `constant` op. For
-/// example
-///
-///  %c5 = constant 5: index
-///  %0 = linalg.init_tensor [%arg0, %c5] : tensor<?x?xf32>
-///
-///  to
-///
-///  %0 = linalg.init_tensor [%arg0, 5] : tensor<?x5xf32>
-struct ReplaceStaticShapeDims : OpRewritePattern<InitTensorOp> {
-  using OpRewritePattern<InitTensorOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(InitTensorOp op,
-                                PatternRewriter &rewriter) const override {
-    SmallVector<Value, 4> dynamicSizes;
-    SmallVector<int64_t, 4> staticSizes;
-    for (unsigned i = 0, e = op.getType().getRank(); i != e; ++i) {
-      // If the size is already static, nothing to do.
-      if (!op.isDynamicSize(i)) {
-        staticSizes.push_back(op.getStaticSize(i));
-        continue;
-      }
-
-      // If the size is dynamic but defined using a `constant` op, get the
-      // constant value to find the static size to use.
-      unsigned operandNum = op.getIndexOfDynamicSize(i);
-      Value sizeOperand = op.getOperand(operandNum);
-      if (auto constantIndexOp = sizeOperand.getDefiningOp<ConstantIndexOp>()) {
-        staticSizes.push_back(constantIndexOp.getValue());
-        continue;
-      }
-
-      // Fallback case. Keep the size dynamic.
-      dynamicSizes.push_back(sizeOperand);
-      staticSizes.push_back(ShapedType::kDynamicSize);
-    }
-    RankedTensorType newType =
-        RankedTensorType::get(staticSizes, op.getType().getElementType());
-    if (newType == op.getType())
-      return failure();
-    auto newOp =
-        rewriter.create<InitTensorOp>(op.getLoc(), newType, dynamicSizes,
-                                      rewriter.getI64ArrayAttr(staticSizes));
-    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
-    return success();
-  }
-};
-
-/// Canonicalize a `linalg.init_tensor` -> `dim` pattern by replacing the `dim`
-/// with
-/// - A constant value if the size is static along the dimension.
-/// - The dynamic value that defines the size of the result of
-///   `linalg.init_tensor` op.
-struct ReplaceDimOfInitTensorOp : public OpRewritePattern<DimOp> {
-  using OpRewritePattern<DimOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(DimOp dimOp,
-                                PatternRewriter &rewriter) const override {
-    auto initTensorOp = dimOp.memrefOrTensor().getDefiningOp<InitTensorOp>();
-    if (!initTensorOp)
-      return failure();
-    auto dimIndex = dimOp.index().getDefiningOp<ConstantIndexOp>();
-    if (!dimIndex)
-      return failure();
-    int64_t index = dimIndex.getValue();
-    if (!initTensorOp.isDynamicSize(index)) {
-      rewriter.replaceOpWithNewOp<ConstantIndexOp>(
-          dimOp, initTensorOp.getStaticSize(index));
-    } else {
-      rewriter.replaceOp(dimOp, initTensorOp.getDynamicSize(index));
-    }
-    return success();
-  }
-};
-} // namespace
-
 static Value getCollapsedInitTensor(OpBuilder &builder,
                                     TensorReshapeOp reshapeOp) {
   Location loc = reshapeOp.getLoc();
@@ -773,6 +694,85 @@ static Value getExpandedInitTensor(OpBuilder &builder,
                                       srcType.getElementType());
 }
 
+namespace {
+/// Change the type of the result of a `linalg.init_tensor` by making the result
+/// type statically sized along dimension that in the original operation where
+/// defined as dynamic, but the size was defined using a `constant` op. For
+/// example
+///
+///  %c5 = constant 5: index
+///  %0 = linalg.init_tensor [%arg0, %c5] : tensor<?x?xf32>
+///
+///  to
+///
+///  %0 = linalg.init_tensor [%arg0, 5] : tensor<?x5xf32>
+struct ReplaceStaticShapeDims : OpRewritePattern<InitTensorOp> {
+  using OpRewritePattern<InitTensorOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(InitTensorOp op,
+                                PatternRewriter &rewriter) const override {
+    SmallVector<Value, 4> dynamicSizes;
+    SmallVector<int64_t, 4> staticSizes;
+    for (unsigned i = 0, e = op.getType().getRank(); i != e; ++i) {
+      // If the size is already static, nothing to do.
+      if (!op.isDynamicSize(i)) {
+        staticSizes.push_back(op.getStaticSize(i));
+        continue;
+      }
+
+      // If the size is dynamic but defined using a `constant` op, get the
+      // constant value to find the static size to use.
+      unsigned operandNum = op.getIndexOfDynamicSize(i);
+      Value sizeOperand = op.getOperand(operandNum);
+      if (auto constantIndexOp = sizeOperand.getDefiningOp<ConstantIndexOp>()) {
+        staticSizes.push_back(constantIndexOp.getValue());
+        continue;
+      }
+
+      // Fallback case. Keep the size dynamic.
+      dynamicSizes.push_back(sizeOperand);
+      staticSizes.push_back(ShapedType::kDynamicSize);
+    }
+    RankedTensorType newType =
+        RankedTensorType::get(staticSizes, op.getType().getElementType());
+    if (newType == op.getType())
+      return failure();
+    auto newOp =
+        rewriter.create<InitTensorOp>(op.getLoc(), newType, dynamicSizes,
+                                      rewriter.getI64ArrayAttr(staticSizes));
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
+    return success();
+  }
+};
+
+/// Canonicalize a `linalg.init_tensor` -> `dim` pattern by replacing the `dim`
+/// with
+/// - A constant value if the size is static along the dimension.
+/// - The dynamic value that defines the size of the result of
+///   `linalg.init_tensor` op.
+struct ReplaceDimOfInitTensorOp : public OpRewritePattern<DimOp> {
+  using OpRewritePattern<DimOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(DimOp dimOp,
+                                PatternRewriter &rewriter) const override {
+    auto initTensorOp = dimOp.memrefOrTensor().getDefiningOp<InitTensorOp>();
+    if (!initTensorOp)
+      return failure();
+    auto dimIndex = dimOp.index().getDefiningOp<ConstantIndexOp>();
+    if (!dimIndex)
+      return failure();
+    int64_t index = dimIndex.getValue();
+    if (!initTensorOp.isDynamicSize(index)) {
+      rewriter.replaceOpWithNewOp<ConstantIndexOp>(
+          dimOp, initTensorOp.getStaticSize(index));
+    } else {
+      rewriter.replaceOp(dimOp, initTensorOp.getDynamicSize(index));
+    }
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 /// Since `init_tensor` operation creates a tensor needed only for its shape, a
 /// subtensor of this is also needed only for its shape. The result can be
@@ -803,17 +803,13 @@ struct FoldInitTensorWithTensorReshapeOp
                                 PatternRewriter &rewriter) const override {
     if (!reshapeOp.src().getDefiningOp<InitTensorOp>())
       return failure();
-    RankedTensorType collapsedType = reshapeOp.getSrcType();
-    RankedTensorType expandedType = reshapeOp.getResultType();
-    bool isCollapsed = expandedType.getRank() < collapsedType.getRank();
-    if (isCollapsed)
-      std::swap(collapsedType, expandedType);
-    Value initTensorOp = isCollapsed
-                             ? getCollapsedInitTensor(rewriter, reshapeOp)
-                             : getExpandedInitTensor(rewriter, reshapeOp);
-    if (!initTensorOp)
-      return failure();
-    rewriter.replaceOp(reshapeOp, initTensorOp);
+    Location loc = reshapeOp.getLoc();
+    SmallVector<Value, 4> resultShapeValues =
+        reshapeOp.getOutputShape(rewriter, loc);
+    Value initTensor = rewriter.create<InitTensorOp>(
+        loc, resultShapeValues, reshapeOp.getResultType().getElementType());
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(
+        reshapeOp, reshapeOp.getResultType(), initTensor);
     return success();
   }
 };
@@ -1255,6 +1251,141 @@ convertReassociationIndicesToMaps(
   return reassociationMaps;
 }
 
+/// For reshape op compute the shape at dimension `dimIndex` of the output in
+/// terms of shape of the `src`, when the reshape op is a collapsing
+/// operation. It is the product of the shape of the collapsed dimensions of the
+/// `src`.
+static Value
+getCollapsedOutputDimFromInputShape(OpBuilder &builder, Location loc,
+                                    int64_t dimIndex, Value src,
+                                    ArrayRef<AffineMap> reassociationMap) {
+  AffineMap map = reassociationMap[dimIndex];
+  unsigned startPos =
+      map.getResults().front().cast<AffineDimExpr>().getPosition();
+  unsigned endPos = map.getResults().back().cast<AffineDimExpr>().getPosition();
+  AffineExpr expr;
+  SmallVector<Value, 2> dynamicDims;
+  for (auto dim : llvm::seq(startPos, endPos + 1)) {
+    dynamicDims.push_back(builder.create<DimOp>(loc, src, dim));
+    AffineExpr currExpr = builder.getAffineSymbolExpr(dim - startPos);
+    expr = (expr ? expr * currExpr : currExpr);
+  }
+  return applyMapToValues(builder, loc,
+                          AffineMap::get(0, endPos - startPos + 1, expr),
+                          dynamicDims)[0];
+}
+
+/// Given the `src` of a collapsing reshape op and its reassociation maps,
+/// compute the shape of the result of the reshape.
+static SmallVector<Value, 4> getCollapsedOutputShapeFromInputShape(
+    OpBuilder &builder, Location loc, Value src,
+    ArrayRef<int64_t> dstStaticShape, ArrayRef<AffineMap> reassociation) {
+  return llvm::to_vector<4>(llvm::map_range(
+      llvm::seq<int64_t>(0, dstStaticShape.size()), [&](int64_t dim) {
+        return getCollapsedOutputDimFromInputShape(builder, loc, dim, src,
+                                                   reassociation);
+      }));
+}
+
+/// Compute a map that for a given dimension of the expanded type gives the
+/// dimension in the collapsed type it maps to. Essentially its the inverse of
+/// the `reassocation` maps.
+static llvm::DenseMap<int64_t, int64_t>
+getExpandedDimToCollapsedDimMap(ArrayRef<AffineMap> reassociation) {
+  llvm::DenseMap<int64_t, int64_t> expandedDimToCollapsedDim;
+  for (auto map : enumerate(reassociation)) {
+    unsigned startPos =
+        map.value().getResults().front().cast<AffineDimExpr>().getPosition();
+    unsigned endPos =
+        map.value().getResults().back().cast<AffineDimExpr>().getPosition();
+    for (auto dim : llvm::seq(startPos, endPos + 1)) {
+      expandedDimToCollapsedDim[dim] = map.index();
+    }
+  }
+  return expandedDimToCollapsedDim;
+}
+
+/// For an expanding reshape op, compute the value for a dimension of the output
+/// from the shape of the input.
+static Value getExpandedOutputDimFromInputShape(
+    OpBuilder &builder, Location loc, int64_t dimIndex, Value src,
+    ArrayRef<int64_t> dstStaticShape, ArrayRef<AffineMap> reassociation,
+    llvm::DenseMap<int64_t, int64_t> &expandedDimToCollapsedDim) {
+  if (!ShapedType::isDynamic(dstStaticShape[dimIndex])) {
+    return builder.create<ConstantIndexOp>(loc, dstStaticShape[dimIndex]);
+  }
+  unsigned sourceDimPos = expandedDimToCollapsedDim[dimIndex];
+  unsigned startPos = reassociation[sourceDimPos]
+                          .getResults()
+                          .front()
+                          .cast<AffineDimExpr>()
+                          .getPosition();
+  unsigned endPos = reassociation[sourceDimPos]
+                        .getResults()
+                        .back()
+                        .cast<AffineDimExpr>()
+                        .getPosition();
+  int64_t linearizedStaticDim = 1;
+  for (auto d :
+       llvm::enumerate(dstStaticShape.slice(startPos, endPos - startPos + 1))) {
+    if (d.index() + startPos == static_cast<unsigned>(dimIndex))
+      continue;
+    assert(!ShapedType::isDynamic(d.value()) &&
+           "single dimension cannot be expanded into multiple dynamic "
+           "dimensions");
+    linearizedStaticDim *= d.value();
+  }
+  Value sourceDim = builder.create<DimOp>(loc, src, sourceDimPos);
+  return applyMapToValues(
+      builder, loc,
+      AffineMap::get(
+          0, 1, builder.getAffineSymbolExpr(0).floorDiv(linearizedStaticDim)),
+      sourceDim)[0];
+}
+
+/// Given the `src` of an expanding reshape op, the reassociation maps and the
+/// result type, compute the shape of the result of the reshape.
+static SmallVector<Value, 4> getExpandedOutputShapeFromInputShape(
+    OpBuilder &builder, Location loc, Value src,
+    ArrayRef<int64_t> dstStaticShape, ArrayRef<AffineMap> reassociation) {
+  llvm::DenseMap<int64_t, int64_t> expandedDimToCollapsedDim =
+      getExpandedDimToCollapsedDimMap(reassociation);
+  return llvm::to_vector<4>(llvm::map_range(
+      llvm::seq<int64_t>(0, dstStaticShape.size()), [&](int64_t dim) {
+        return getExpandedOutputDimFromInputShape(builder, loc, dim, src,
+                                                  dstStaticShape, reassociation,
+                                                  expandedDimToCollapsedDim);
+      }));
+}
+
+SmallVector<Value, 4> mlir::linalg::getReshapeOutputShapeFromInputShape(
+    OpBuilder &builder, Location loc, Value src,
+    ArrayRef<int64_t> dstStaticShape, ArrayRef<AffineMap> reassocation) {
+  return dstStaticShape.size() >
+                 static_cast<size_t>(src.getType().cast<ShapedType>().getRank())
+             ? getExpandedOutputShapeFromInputShape(
+                   builder, loc, src, dstStaticShape, reassocation)
+             : getCollapsedOutputShapeFromInputShape(
+                   builder, loc, src, dstStaticShape, reassocation);
+}
+
+/// For a reshape op, compute the value of a given dimension of the output
+/// (`dimIndex`) from the shape of the inputs and type of the result.
+static Value getReshapeOutputDimFromInputShape(
+    OpBuilder &builder, Location loc, int64_t dimIndex, Value src,
+    ArrayRef<int64_t> dstStaticShape, ArrayRef<AffineMap> reassociation) {
+  if (dstStaticShape.size() >
+      static_cast<size_t>(src.getType().cast<ShapedType>().getRank())) {
+    llvm::DenseMap<int64_t, int64_t> expandedDimToCollapsedDim =
+        getExpandedDimToCollapsedDimMap(reassociation);
+    return getExpandedOutputDimFromInputShape(builder, loc, dimIndex, src,
+                                              dstStaticShape, reassociation,
+                                              expandedDimToCollapsedDim);
+  }
+  return getCollapsedOutputDimFromInputShape(builder, loc, dimIndex, src,
+                                             reassociation);
+}
+
 void mlir::linalg::ReshapeOp::build(OpBuilder &b, OperationState &result,
                                     Value src,
                                     ArrayRef<ReassociationExprs> reassociation,
@@ -1478,12 +1609,35 @@ struct FoldReshapeWithConstant : OpRewritePattern<TensorReshapeOp> {
     return success();
   }
 };
+
+/// Canonicalize dim ops that use the output shape with dim of the input.
+struct ReplaceDimOfReshapeOpResult : OpRewritePattern<DimOp> {
+  using OpRewritePattern<DimOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(DimOp dimOp,
+                                PatternRewriter &rewriter) const override {
+    Value dimValue = dimOp.memrefOrTensor();
+    Optional<int64_t> dimIndex = dimOp.getConstantIndex();
+    if (!dimIndex)
+      return failure();
+
+    auto reshapeOp = dimValue.getDefiningOp<TensorReshapeOp>();
+    if (!reshapeOp)
+      return failure();
+
+    rewriter.replaceOp(dimOp,
+                       getReshapeOutputDimFromInputShape(
+                           rewriter, dimOp.getLoc(), *dimIndex, reshapeOp.src(),
+                           reshapeOp.getResultType().getShape(),
+                           reshapeOp.getReassociationMaps()));
+    return success();
+  }
+};
 } // namespace
 
 void TensorReshapeOp::getCanonicalizationPatterns(
     OwningRewritePatternList &results, MLIRContext *context) {
-  results.insert<CollapseReshapeOps<TensorReshapeOp>, FoldReshapeWithConstant>(
-      context);
+  results.insert<CollapseReshapeOps<TensorReshapeOp>, FoldReshapeWithConstant,
+                 ReplaceDimOfReshapeOpResult>(context);
 }
 
 //===----------------------------------------------------------------------===//

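Putting the LinalgOps.cpp pieces together, a sketch of the end-to-end effect
on the expansion case (matching the updated tests below; the all-dynamic
init_tensor and tensor.cast produced by FoldInitTensorWithTensorReshapeOp are
folded by the existing canonicalization patterns):

  // Before: a reshape of an init_tensor, needed only for its shape.
  %0 = linalg.init_tensor [6, 5, %arg0] : tensor<6x5x?xf32>
  %1 = linalg.tensor_reshape %0
    [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1)>,
     affine_map<(d0, d1, d2, d3, d4, d5) -> (d2)>,
     affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>] :
     tensor<6x5x?xf32> into tensor<2x3x5x4x?x7xf32>

  // After canonicalization: the result shape is rebuilt from the input sizes
  // and both the tensor_reshape and the original init_tensor go away.
  %d = affine.apply affine_map<()[s0] -> (s0 floordiv 28)>()[%arg0]
  %1 = linalg.init_tensor [2, 3, 5, 4, %d, 7] : tensor<2x3x5x4x?x7xf32>
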
diff  --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 75abef70cd4e..2fb5eb3086e6 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -560,10 +560,10 @@ func @init_tensor_reshape_expansion(%arg0 : index) -> tensor<2x3x5x4x?x7xf32> {
      tensor<6x5x?xf32> into tensor<2x3x5x4x?x7xf32>
   return %1 : tensor<2x3x5x4x?x7xf32>
 }
+//      CHECK: #[[MAP:.+]] = affine_map<()[s0] -> (s0 floordiv 28)>
 //      CHECK: func @init_tensor_reshape_expansion
 // CHECK-SAME:   %[[ARG0:.+]]: index
-//      CHECK:   %[[C28:.+]] = constant 28 : index
-//      CHECK:   %[[T0:.+]] = divi_unsigned %[[ARG0]], %[[C28]]
+//      CHECK:   %[[T0:.+]] = affine.apply #[[MAP]]()[%[[ARG0]]]
 //      CHECK:   %[[T1:.+]] = linalg.init_tensor [2, 3, 5, 4, %[[T0]], 7]
 //      CHECK:   return %[[T1]]
 
@@ -578,10 +578,10 @@ func @init_tensor_reshape_collapse(%arg0 : index) -> tensor<6x5x?xf32> {
     tensor<2x3x5x4x?x7xf32> into tensor<6x5x?xf32>
   return %1 : tensor<6x5x?xf32>
 }
+//      CHECK: #[[MAP:.+]] = affine_map<()[s0] -> (s0 * 28)>
 //      CHECK: func @init_tensor_reshape_collapse
 // CHECK-SAME:   %[[ARG0:.+]]: index
-//      CHECK:   %[[C28:.+]] = constant 28 : index
-//      CHECK:   %[[T0:.+]] = muli %[[ARG0]], %[[C28]]
+//      CHECK:   %[[T0:.+]] = affine.apply #[[MAP]]()[%[[ARG0]]]
 //      CHECK:   %[[T1:.+]] = linalg.init_tensor [6, 5, %[[T0]]]
 //      CHECK:   return %[[T1]]
 
@@ -716,3 +716,54 @@ func @dead_linalg_tensor(%arg0 : tensor<7x7xi32>, %arg1 : tensor<7x7xf32>,
   } : tensor<?x?xf32> to tensor<2x4xf32>
   return
 }
+
+// -----
+
+func @dim_reshape_expansion(%arg0 : tensor<6x5x?xf32>) -> (index, index, index)
+{
+  %c1 = constant 1 : index
+  %c3 = constant 3 : index
+  %c4 = constant 4 : index
+  %0 = linalg.tensor_reshape %arg0
+    [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1)>,
+     affine_map<(d0, d1, d2, d3, d4, d5) -> (d2)>,
+     affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>] :
+     tensor<6x5x?xf32> into tensor<2x3x5x4x?x7xf32>
+  %1 = dim %0, %c1 : tensor<2x3x5x4x?x7xf32>
+  %2 = dim %0, %c3 : tensor<2x3x5x4x?x7xf32>
+  %3 = dim %0, %c4 : tensor<2x3x5x4x?x7xf32>
+  return %1, %2, %3 : index, index, index
+}
+//      CHECK: #[[MAP:.+]] = affine_map<()[s0] -> (s0 floordiv 28)>
+//      CHECK: func @dim_reshape_expansion
+// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<6x5x?xf32>
+//  CHECK-DAG:   %[[C2:.+]] = constant 2 : index
+//  CHECK-DAG:   %[[C3:.+]] = constant 3 : index
+//  CHECK-DAG:   %[[C4:.+]] = constant 4 : index
+//      CHECK:   %[[D0:.+]] = dim %[[ARG0]], %[[C2]]
+//      CHECK:   %[[D1:.+]] = affine.apply #[[MAP]]()[%[[D0]]]
+//      CHECK:   return %[[C3]], %[[C4]], %[[D1]]
+
+// -----
+
+func @dim_reshape_collapse(%arg0 : tensor<2x3x5x4x?x7xf32>) -> (index, index)
+{
+  %c1 = constant 1 : index
+  %c2 = constant 2 : index
+  %0 = linalg.tensor_reshape %arg0
+    [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1)>,
+     affine_map<(d0, d1, d2, d3, d4, d5) -> (d2)>,
+     affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>] :
+     tensor<2x3x5x4x?x7xf32> into tensor<6x5x?xf32>
+  %1 = dim %0, %c1 : tensor<6x5x?xf32>
+  %2 = dim %0, %c2 : tensor<6x5x?xf32>
+  return %1, %2 : index, index
+}
+//      CHECK: #[[MAP:.+]] = affine_map<()[s0] -> (s0 * 28)>
+//      CHECK: func @dim_reshape_collapse
+// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<2x3x5x4x?x7xf32>
+//  CHECK-DAG:   %[[C4:.+]] = constant 4 : index
+//  CHECK-DAG:   %[[C5:.+]] = constant 5 : index
+//      CHECK:   %[[D0:.+]] = dim %[[ARG0]], %[[C4]]
+//      CHECK:   %[[D1:.+]] = affine.apply #[[MAP]]()[%[[D0]]]
+//      CHECK:   return %[[C5]], %[[D1]]


        


More information about the Mlir-commits mailing list