[llvm-branch-commits] [mlir] 9c0dc0b - [mlir][Linalg] Fold init_tensor -> linalg.tensor_reshape.
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Mon Jan 11 09:27:03 PST 2021
Author: MaheshRavishankar
Date: 2021-01-11T09:22:35-08:00
New Revision: 9c0dc0b2c1cc973056237bdd80dbba749941ea63
URL: https://github.com/llvm/llvm-project/commit/9c0dc0b2c1cc973056237bdd80dbba749941ea63
DIFF: https://github.com/llvm/llvm-project/commit/9c0dc0b2c1cc973056237bdd80dbba749941ea63.diff
LOG: [mlir][Linalg] Fold init_tensor -> linalg.tensor_reshape.
Reshaping an init_tensor can be folded into an init_tensor op of the
final (reshaped) type.
Differential Revision: https://reviews.llvm.org/D93773
Added:
Modified:
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/test/Dialect/Linalg/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 8a97753e1a5c..8732065bb042 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -718,9 +718,123 @@ struct ReplaceDimOfInitTensorOp : public OpRewritePattern<DimOp> {
};
} // namespace
+/// Builds a `linalg.init_tensor` with the shape of `reshapeOp`'s result, for
+/// a reshape whose result has lower rank than its source (the collapsing case
+/// selected in FoldWithTensorReshapeOp).
+///
+/// Each reassociation map lists the source dimensions that are folded into
+/// one result dimension. The result size is the product of those source
+/// sizes. Static sizes are multiplied into a compile-time constant. Dynamic
+/// sizes are read from the source with `dim` ops and multiplied with `muli`.
+/// When a group has any dynamic size, the runtime product is scaled by the
+/// static product, unless that product is 1.
+///
+/// Always returns a non-null value.
+static Value getCollapsedInitTensor(OpBuilder &builder,
+ TensorReshapeOp reshapeOp) {
+ Location loc = reshapeOp.getLoc();
+ SmallVector<Value, 4> dynamicShapes;
+ SmallVector<int64_t, 4> staticShapes;
+ auto reassociation = reshapeOp.getReassociationMaps();
+ Value src = reshapeOp.src();
+ RankedTensorType srcType = reshapeOp.getSrcType();
+ ArrayRef<int64_t> srcShape = srcType.getShape();
+ // One iteration per result dimension.
+ for (auto map : reassociation) {
+ // Running product of the group's dynamic sizes (null until the first one).
+ Value linearizedDynamicDim = nullptr;
+ // Running product of the group's static sizes.
+ int64_t linearizedStaticDim = 1;
+ // Visit the source dimension positions named by this map's results.
+ for (unsigned i : llvm::map_range(map.getResults(), [](AffineExpr e) {
+ return e.cast<AffineDimExpr>().getPosition();
+ })) {
+ if (ShapedType::isDynamic(srcShape[i])) {
+ Value shapeVal = builder.create<DimOp>(loc, src, i);
+ if (linearizedDynamicDim) {
+ linearizedDynamicDim =
+ builder.create<MulIOp>(loc, linearizedDynamicDim, shapeVal);
+ } else {
+ linearizedDynamicDim = shapeVal;
+ }
+ } else {
+ linearizedStaticDim *= srcShape[i];
+ }
+ }
+ // Any dynamic size in the group makes the collapsed size dynamic.
+ // Fold the static product in at runtime, skipping a multiply by 1.
+ if (linearizedDynamicDim) {
+ if (linearizedStaticDim != 1) {
+ linearizedDynamicDim = builder.create<MulIOp>(
+ loc, linearizedDynamicDim,
+ builder.create<ConstantIndexOp>(loc, linearizedStaticDim));
+ }
+ dynamicShapes.push_back(linearizedDynamicDim);
+ staticShapes.push_back(ShapedType::kDynamicSize);
+ } else {
+ staticShapes.push_back(linearizedStaticDim);
+ }
+ }
+ // The new init_tensor keeps the source element type.
+ return builder.create<InitTensorOp>(loc, dynamicShapes, staticShapes,
+ srcType.getElementType());
+}
+
+/// Builds a `linalg.init_tensor` with the shape of `reshapeOp`'s result, for
+/// a reshape whose result does not have lower rank than its source (the
+/// expanding case selected in FoldWithTensorReshapeOp). Each reassociation
+/// map lists the result dimensions that together form one source dimension.
+///
+/// Static result sizes are copied as-is. A group may contain at most one
+/// dynamic result dimension. Its size is computed at runtime as the matching
+/// source size divided (`divi_unsigned`) by the product of the group's static
+/// sizes.
+/// NOTE(review): assumes the source size is a multiple of that product, as
+/// the reshape itself should guarantee -- confirm.
+///
+/// Returns a null Value if a group has more than one dynamic dimension, or
+/// has a dynamic dimension while the matching source dimension is static.
+/// All checks run before any operation is created, so a null return leaves
+/// the IR untouched. A rewrite pattern must not mutate IR when it fails.
+static Value getExpandedInitTensor(OpBuilder &builder,
+                                   TensorReshapeOp reshapeOp) {
+  SmallVector<Value, 4> dynamicShapes;
+  SmallVector<int64_t, 4> staticShapes;
+  auto reassociation = reshapeOp.getReassociationMaps();
+  Value src = reshapeOp.src();
+  RankedTensorType srcType = reshapeOp.getSrcType();
+  ArrayRef<int64_t> srcShape = srcType.getShape();
+  ArrayRef<int64_t> dstShape = reshapeOp.getResultType().getShape();
+  Location loc = reshapeOp.getLoc();
+
+  // Phase 1: validate and compute shapes without touching the IR.
+  // For every reassociation group, record the static product to divide the
+  // source size by, or llvm::None when the group has no dynamic dimension
+  // (and therefore needs no dynamic size operand).
+  SmallVector<Optional<int64_t>, 4> groupDivisors;
+  for (auto map : llvm::enumerate(reassociation)) {
+    int64_t linearizedStaticDim = 1;
+    bool hasDynamic = false;
+    for (AffineExpr expr : map.value().getResults()) {
+      unsigned i = expr.cast<AffineDimExpr>().getPosition();
+      if (ShapedType::isDynamic(dstShape[i])) {
+        // Only one of the dimensions of the expanded shape may be dynamic;
+        // otherwise the individual sizes cannot be recovered.
+        if (hasDynamic)
+          return nullptr;
+        hasDynamic = true;
+        staticShapes.push_back(ShapedType::kDynamicSize);
+        continue;
+      }
+      staticShapes.push_back(dstShape[i]);
+      linearizedStaticDim *= dstShape[i];
+    }
+    if (!hasDynamic) {
+      groupDivisors.push_back(llvm::None);
+      continue;
+    }
+    // If the expanded dimensions have a dynamic shape, the src shape must be
+    // dynamic as well.
+    if (!ShapedType::isDynamic(srcShape[map.index()]))
+      return nullptr;
+    groupDivisors.push_back(linearizedStaticDim);
+  }
+
+  // Phase 2: every check passed, so it is now safe to create operations.
+  // Creation order (dim, constant, divi_unsigned) matches the original code.
+  for (auto divisor : llvm::enumerate(groupDivisors)) {
+    if (!divisor.value())
+      continue;
+    Value dynamicDim = builder.create<DimOp>(loc, src, divisor.index());
+    if (*divisor.value() != 1) {
+      dynamicDim = builder.create<UnsignedDivIOp>(
+          loc, dynamicDim,
+          builder.create<ConstantIndexOp>(loc, *divisor.value()));
+    }
+    dynamicShapes.push_back(dynamicDim);
+  }
+  return builder.create<InitTensorOp>(loc, dynamicShapes, staticShapes,
+                                      srcType.getElementType());
+}
+
+namespace {
+/// Folds `linalg.tensor_reshape(linalg.init_tensor)` into a single
+/// `linalg.init_tensor` with the reshape's result shape. Dynamic sizes of the
+/// new op are computed from `dim` ops on the original init_tensor.
+struct FoldWithTensorReshapeOp : public OpRewritePattern<TensorReshapeOp> {
+  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
+                                PatternRewriter &rewriter) const override {
+    if (!reshapeOp.src().getDefiningOp<InitTensorOp>())
+      return failure();
+    // The reshape collapses when its result has fewer dimensions than its
+    // source; otherwise it expands.
+    bool isCollapsing =
+        reshapeOp.getResultType().getRank() < reshapeOp.getSrcType().getRank();
+    Value initTensorOp = isCollapsing
+                             ? getCollapsedInitTensor(rewriter, reshapeOp)
+                             : getExpandedInitTensor(rewriter, reshapeOp);
+    // Only the expanding helper can bail out with a null value.
+    if (!initTensorOp)
+      return failure();
+    rewriter.replaceOp(reshapeOp, initTensorOp);
+    return success();
+  }
+};
+} // namespace
+
void InitTensorOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
- results.insert<ReplaceDimOfInitTensorOp, ReplaceStaticShapeDims>(context);
+ // FoldWithTensorReshapeOp is rooted at linalg.tensor_reshape rather than at
+ // linalg.init_tensor. It is registered here because it only fires when the
+ // reshape's source is produced by an init_tensor.
+ results.insert<FoldWithTensorReshapeOp, ReplaceDimOfInitTensorOp,
+ ReplaceStaticShapeDims>(context);
}
//===----------------------------------------------------------------------===//
@@ -1043,23 +1157,23 @@ static LogicalResult verifyReshapeLikeShapes(OpTy op, ShapedType collapsedType,
ArrayRef<int64_t> expandedShape = expandedType.getShape();
unsigned expandedDimStart = 0;
for (auto map : llvm::enumerate(op.getReassociationMaps())) {
- Optional<int64_t> dynamicDims;
+ Optional<int64_t> dynamicShape;
int64_t linearizedStaticShape = 1;
for (auto dim : llvm::enumerate(expandedShape.slice(
expandedDimStart, map.value().getNumResults()))) {
if (ShapedType::isDynamic(dim.value())) {
- if (isExpandingReshape && dynamicDims) {
+ if (isExpandingReshape && dynamicShape) {
return op->emitOpError("invalid to have a single dimension (")
<< map.index() << ") expanded into multiple dynamic dims ("
- << expandedDimStart + dynamicDims.getValue() << ","
+ << expandedDimStart + dynamicShape.getValue() << ","
<< expandedDimStart + dim.index() << ")";
}
- dynamicDims = dim.index();
+ dynamicShape = dim.index();
} else {
linearizedStaticShape *= dim.value();
}
}
- if (dynamicDims) {
+ if (dynamicShape) {
if (!ShapedType::isDynamic(collapsedShape[map.index()])) {
return op->emitOpError("expected dimension ")
<< map.index()
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 4102a1326b96..6b806c801341 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -413,3 +413,39 @@ func @init_tensor_dim_of_linalg_result(%arg_0 : tensor<?xf32>,
// CHECK-SAME: [[ARG_0:%.*]]: tensor<?xf32>, [[ARG_1:%.*]]: tensor<?xf32>)
// CHECK: dim [[ARG_0]]
// CHECK: dim [[ARG_1]]
+
+// -----
+
+// An expanding tensor_reshape of an init_tensor folds into a single
+// init_tensor of the expanded shape. The dynamic size is the collapsed
+// dynamic size divided by the static sizes grouped with it (4 * 7 = 28).
+func @init_tensor_reshape_expansion(%arg0 : index) -> tensor<2x3x5x4x?x7xf32> {
+ %0 = linalg.init_tensor [6, 5, %arg0] : tensor<6x5x?xf32>
+ %1 = linalg.tensor_reshape %0
+ [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1)>,
+ affine_map<(d0, d1, d2, d3, d4, d5) -> (d2)>,
+ affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>] :
+ tensor<6x5x?xf32> into tensor<2x3x5x4x?x7xf32>
+ return %1 : tensor<2x3x5x4x?x7xf32>
+}
+// CHECK: func @init_tensor_reshape_expansion
+// CHECK-SAME: %[[ARG0:.+]]: index
+// CHECK: %[[C28:.+]] = constant 28 : index
+// CHECK: %[[T0:.+]] = divi_unsigned %[[ARG0]], %[[C28]]
+// CHECK: %[[T1:.+]] = linalg.init_tensor [2, 3, 5, 4, %[[T0]], 7]
+// CHECK: return %[[T1]]
+
+// -----
+
+// A collapsing tensor_reshape of an init_tensor folds into a single
+// init_tensor of the collapsed shape. The dynamic size is the expanded
+// dynamic size multiplied by the static sizes grouped with it (4 * 7 = 28).
+func @init_tensor_reshape_collapse(%arg0 : index) -> tensor<6x5x?xf32> {
+ %0 = linalg.init_tensor [2, 3, 5, 4, %arg0, 7] : tensor<2x3x5x4x?x7xf32>
+ %1 = linalg.tensor_reshape %0
+ [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1)>,
+ affine_map<(d0, d1, d2, d3, d4, d5) -> (d2)>,
+ affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>] :
+ tensor<2x3x5x4x?x7xf32> into tensor<6x5x?xf32>
+ return %1 : tensor<6x5x?xf32>
+}
+// CHECK: func @init_tensor_reshape_collapse
+// CHECK-SAME: %[[ARG0:.+]]: index
+// CHECK: %[[C28:.+]] = constant 28 : index
+// CHECK: %[[T0:.+]] = muli %[[ARG0]], %[[C28]]
+// CHECK: %[[T1:.+]] = linalg.init_tensor [6, 5, %[[T0]]]
+// CHECK: return %[[T1]]
More information about the llvm-branch-commits mailing list