[llvm-branch-commits] [mlir] 9c0dc0b - [mlir][Linalg] Fold init_tensor -> linalg.tensor_reshape.

via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Mon Jan 11 09:27:03 PST 2021


Author: MaheshRavishankar
Date: 2021-01-11T09:22:35-08:00
New Revision: 9c0dc0b2c1cc973056237bdd80dbba749941ea63

URL: https://github.com/llvm/llvm-project/commit/9c0dc0b2c1cc973056237bdd80dbba749941ea63
DIFF: https://github.com/llvm/llvm-project/commit/9c0dc0b2c1cc973056237bdd80dbba749941ea63.diff

LOG: [mlir][Linalg] Fold init_tensor -> linalg.tensor_reshape.

Reshaping an init_tensor can be folded into an init_tensor op of the
final type.
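
For example (mirroring the new expansion test below; reassociation maps
elided as [...]), an expanding reshape of an init_tensor such as

  %0 = linalg.init_tensor [6, 5, %sz] : tensor<6x5x?xf32>
  %1 = linalg.tensor_reshape %0 [...]
      : tensor<6x5x?xf32> into tensor<2x3x5x4x?x7xf32>

folds to an init_tensor of the result type, with the dynamic extent
recovered by dividing out the static factors in its group (4 * 7 = 28):

  %c28 = constant 28 : index
  %d = divi_unsigned %sz, %c28 : index
  %1 = linalg.init_tensor [2, 3, 5, 4, %d, 7] : tensor<2x3x5x4x?x7xf32>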

Differential Revision: https://reviews.llvm.org/D93773

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/test/Dialect/Linalg/canonicalize.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 8a97753e1a5c..8732065bb042 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -718,9 +718,127 @@ struct ReplaceDimOfInitTensorOp : public OpRewritePattern<DimOp> {
 };
 } // namespace
 
+static Value getCollapsedInitTensor(OpBuilder &builder,
+                                    TensorReshapeOp reshapeOp) {
+  Location loc = reshapeOp.getLoc();
+  SmallVector<Value, 4> dynamicShapes;
+  SmallVector<int64_t, 4> staticShapes;
+  auto reassociation = reshapeOp.getReassociationMaps();
+  Value src = reshapeOp.src();
+  RankedTensorType srcType = reshapeOp.getSrcType();
+  ArrayRef<int64_t> srcShape = srcType.getShape();
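+  // Each reassociation group collapses to one result dimension; its extent is
+  // the product of the group's source extents (dynamic ones chained via muli).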
+  for (auto map : reassociation) {
+    Value linearizedDynamicDim = nullptr;
+    int64_t linearizedStaticDim = 1;
+    for (unsigned i : llvm::map_range(map.getResults(), [](AffineExpr e) {
+           return e.cast<AffineDimExpr>().getPosition();
+         })) {
+      if (ShapedType::isDynamic(srcShape[i])) {
+        Value shapeVal = builder.create<DimOp>(loc, src, i);
+        if (linearizedDynamicDim) {
+          linearizedDynamicDim =
+              builder.create<MulIOp>(loc, linearizedDynamicDim, shapeVal);
+        } else {
+          linearizedDynamicDim = shapeVal;
+        }
+      } else {
+        linearizedStaticDim *= srcShape[i];
+      }
+    }
+    if (linearizedDynamicDim) {
+      if (linearizedStaticDim != 1) {
+        linearizedDynamicDim = builder.create<MulIOp>(
+            loc, linearizedDynamicDim,
+            builder.create<ConstantIndexOp>(loc, linearizedStaticDim));
+      }
+      dynamicShapes.push_back(linearizedDynamicDim);
+      staticShapes.push_back(ShapedType::kDynamicSize);
+    } else {
+      staticShapes.push_back(linearizedStaticDim);
+    }
+  }
+  return builder.create<InitTensorOp>(loc, dynamicShapes, staticShapes,
+                                      srcType.getElementType());
+}
+
+static Value getExpandedInitTensor(OpBuilder &builder,
+                                   TensorReshapeOp reshapeOp) {
+  SmallVector<Value, 4> dynamicShapes;
+  SmallVector<int64_t, 4> staticShapes;
+  auto reassociation = reshapeOp.getReassociationMaps();
+  Value src = reshapeOp.src();
+  RankedTensorType srcType = reshapeOp.getSrcType();
+  ArrayRef<int64_t> srcShape = srcType.getShape();
+  ArrayRef<int64_t> dstShape = reshapeOp.getResultType().getShape();
+  Location loc = reshapeOp.getLoc();
+  for (auto map : enumerate(reassociation)) {
+    int64_t linearizedStaticDim = 1;
+    bool hasDynamic = false;
+    for (unsigned i :
+         llvm::map_range(map.value().getResults(), [](AffineExpr e) {
+           return e.cast<AffineDimExpr>().getPosition();
+         })) {
+      if (ShapedType::isDynamic(dstShape[i])) {
+        // At most one dimension in each expanded group may be dynamic.
+        if (hasDynamic)
+          return nullptr;
+        hasDynamic = true;
+        staticShapes.push_back(ShapedType::kDynamicSize);
+        continue;
+      }
+      staticShapes.push_back(dstShape[i]);
+      linearizedStaticDim *= dstShape[i];
+    }
+    if (hasDynamic) {
+      // If any dimension in the expanded group is dynamic, the corresponding
+      // src dimension must be dynamic as well.
+      if (!ShapedType::isDynamic(srcShape[map.index()]))
+        return nullptr;
+      Value dynamicDim = builder.create<DimOp>(loc, src, map.index());
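+      // Divide out the group's static extents to recover the dynamic extent.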
+      if (linearizedStaticDim != 1) {
+        dynamicDim = builder.create<UnsignedDivIOp>(
+            loc, dynamicDim,
+            builder.create<ConstantIndexOp>(loc, linearizedStaticDim));
+      }
+      dynamicShapes.push_back(dynamicDim);
+    }
+  }
+  return builder.create<InitTensorOp>(loc, dynamicShapes, staticShapes,
+                                      srcType.getElementType());
+}
+
+namespace {
+struct FoldWithTensorReshapeOp : public OpRewritePattern<TensorReshapeOp> {
+  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
+                                PatternRewriter &rewriter) const override {
+    if (!reshapeOp.src().getDefiningOp<InitTensorOp>())
+      return failure();
+    RankedTensorType collapsedType = reshapeOp.getSrcType();
+    RankedTensorType expandedType = reshapeOp.getResultType();
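+    // The reshape is a collapse if the result rank is lower than the source rank.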
+    bool isCollapsed = expandedType.getRank() < collapsedType.getRank();
+    if (isCollapsed)
+      std::swap(collapsedType, expandedType);
+    Value initTensorOp = isCollapsed
+                             ? getCollapsedInitTensor(rewriter, reshapeOp)
+                             : getExpandedInitTensor(rewriter, reshapeOp);
+    if (!initTensorOp)
+      return failure();
+    rewriter.replaceOp(reshapeOp, initTensorOp);
+    return success();
+  }
+};
+} // namespace
+
 void InitTensorOp::getCanonicalizationPatterns(
     OwningRewritePatternList &results, MLIRContext *context) {
-  results.insert<ReplaceDimOfInitTensorOp, ReplaceStaticShapeDims>(context);
+  results.insert<FoldWithTensorReshapeOp, ReplaceDimOfInitTensorOp,
+                 ReplaceStaticShapeDims>(context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -1043,23 +1157,23 @@ static LogicalResult verifyReshapeLikeShapes(OpTy op, ShapedType collapsedType,
   ArrayRef<int64_t> expandedShape = expandedType.getShape();
   unsigned expandedDimStart = 0;
   for (auto map : llvm::enumerate(op.getReassociationMaps())) {
-    Optional<int64_t> dynamicDims;
+    Optional<int64_t> dynamicShape;
     int64_t linearizedStaticShape = 1;
     for (auto dim : llvm::enumerate(expandedShape.slice(
              expandedDimStart, map.value().getNumResults()))) {
       if (ShapedType::isDynamic(dim.value())) {
-        if (isExpandingReshape && dynamicDims) {
+        if (isExpandingReshape && dynamicShape) {
           return op->emitOpError("invalid to have a single dimension (")
                  << map.index() << ") expanded into multiple dynamic dims ("
-                 << expandedDimStart + dynamicDims.getValue() << ","
+                 << expandedDimStart + dynamicShape.getValue() << ","
                  << expandedDimStart + dim.index() << ")";
         }
-        dynamicDims = dim.index();
+        dynamicShape = dim.index();
       } else {
         linearizedStaticShape *= dim.value();
       }
     }
-    if (dynamicDims) {
+    if (dynamicShape) {
       if (!ShapedType::isDynamic(collapsedShape[map.index()])) {
         return op->emitOpError("expected dimension ")
                << map.index()

diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 4102a1326b96..6b806c801341 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -413,3 +413,40 @@ func @init_tensor_dim_of_linalg_result(%arg_0 : tensor<?xf32>,
 // CHECK-SAME: [[ARG_0:%.*]]: tensor<?xf32>, [[ARG_1:%.*]]: tensor<?xf32>)
 // CHECK: dim [[ARG_0]]
 // CHECK: dim [[ARG_1]]
+
+// -----
+
+func @init_tensor_reshape_expansion(%arg0 : index) -> tensor<2x3x5x4x?x7xf32> {
+  %0 = linalg.init_tensor [6, 5, %arg0] : tensor<6x5x?xf32>
+  %1 = linalg.tensor_reshape %0
+    [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1)>,
+     affine_map<(d0, d1, d2, d3, d4, d5) -> (d2)>,
+     affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>] :
+     tensor<6x5x?xf32> into tensor<2x3x5x4x?x7xf32>
+  return %1 : tensor<2x3x5x4x?x7xf32>
+}
+//      CHECK: func @init_tensor_reshape_expansion
+// CHECK-SAME:   %[[ARG0:.+]]: index
+//      CHECK:   %[[C28:.+]] = constant 28 : index
+//      CHECK:   %[[T0:.+]] = divi_unsigned %[[ARG0]], %[[C28]]
+//      CHECK:   %[[T1:.+]] = linalg.init_tensor [2, 3, 5, 4, %[[T0]], 7]
+//      CHECK:   return %[[T1]]
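+// Note: 28 = 4 * 7, the product of the static extents grouped with the dynamic dim.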
+
+// -----
+
+func @init_tensor_reshape_collapse(%arg0 : index) -> tensor<6x5x?xf32> {
+  %0 = linalg.init_tensor [2, 3, 5, 4, %arg0, 7] : tensor<2x3x5x4x?x7xf32>
+  %1 = linalg.tensor_reshape %0
+    [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1)>,
+     affine_map<(d0, d1, d2, d3, d4, d5) -> (d2)>,
+     affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>] :
+    tensor<2x3x5x4x?x7xf32> into tensor<6x5x?xf32>
+  return %1 : tensor<6x5x?xf32>
+}
+//      CHECK: func @init_tensor_reshape_collapse
+// CHECK-SAME:   %[[ARG0:.+]]: index
+//      CHECK:   %[[C28:.+]] = constant 28 : index
+//      CHECK:   %[[T0:.+]] = muli %[[ARG0]], %[[C28]]
+//      CHECK:   %[[T1:.+]] = linalg.init_tensor [6, 5, %[[T0]]]
+//      CHECK:   return %[[T1]]
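
The new patterns run as part of the regular canonicalizer, so, assuming a
file input.mlir containing IR like the tests above (a linalg.init_tensor
feeding a linalg.tensor_reshape), they can be exercised with the standard
pass:

  mlir-opt -canonicalize input.mlir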

More information about the llvm-branch-commits mailing list