[llvm-branch-commits] [mlir] 430d43e - [mlir][Linalg] Disable fusion of tensor_reshape op by expansion when unit-dims are involved
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Fri Jan 22 12:59:44 PST 2021
Author: MaheshRavishankar
Date: 2021-01-22T12:55:25-08:00
New Revision: 430d43e010bdd07d73c4d0d6536206d22d35a2cb
URL: https://github.com/llvm/llvm-project/commit/430d43e010bdd07d73c4d0d6536206d22d35a2cb
DIFF: https://github.com/llvm/llvm-project/commit/430d43e010bdd07d73c4d0d6536206d22d35a2cb.diff
LOG: [mlir][Linalg] Disable fusion of tensor_reshape op by expansion when unit-dims are involved
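Fusion of generic/indexed_generic operations with tensor_reshape by
expansion is disabled when the reshape only adds or removes unit
dimensions, since such a fusion merely introduces loops with a unit trip
count. These reshapes are instead folded by linearization as part of the
patterns that drop unit-extent dimensions.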
Differential Revision: https://reviews.llvm.org/D94626
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/Passes.h
mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
mlir/test/Dialect/Linalg/reshape_fusion.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h
index d041df86d169..5d68328acc7e 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -72,6 +72,15 @@ void populateFoldReshapeOpsByExpansionPatterns(
void populateFoldReshapeOpsByLinearizationPatterns(
MLIRContext *context, OwningRewritePatternList &patterns);
+/// Patterns to fold a collapsing (expanding) tensor_reshape operation with its
+/// producer (consumer) generic/indexed_generic operation by linearizing the
+/// indexing map used to access the source (target) of the reshape operation in
+/// the generic/indexed_generic operation. The patterns are applied only when
+/// the tensor reshape involved is collapsing (introducing) unit-extent
+/// dimensions.
+void populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
+ MLIRContext *context, OwningRewritePatternList &patterns);
+
/// Patterns for fusing linalg operation on tensors.
void populateLinalgTensorOpsFusionPatterns(MLIRContext *context,
OwningRewritePatternList &patterns);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index 8d09d58b9d7a..3c7b2223ee49 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -497,6 +497,7 @@ void mlir::populateLinalgFoldUnitExtentDimsPatterns(
ReplaceUnitExtentTensors<IndexedGenericOp>>(context);
TensorReshapeOp::getCanonicalizationPatterns(patterns, context);
patterns.insert<FoldReshapeOpWithUnitExtent>(context);
+ populateFoldUnitDimsReshapeOpsByLinearizationPatterns(context, patterns);
}
namespace {
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
index 670d456ad2f2..0c5b8486824f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -302,9 +302,18 @@ static AffineMap linearizeCollapsedDims(AffineMap sourceMap,
assert(!collapsedDims.empty());
unsigned startDim =
collapsedDims.front().cast<AffineDimExpr>().getPosition();
- AffineExpr linearizedExpr = makeCanonicalStridedLayoutExpr(
- sourceShape.slice(startDim, collapsedDims.size()),
- sourceExprs.slice(startDim, collapsedDims.size()), context);
+ SmallVector<int64_t, 4> sizes;
+ SmallVector<AffineExpr, 4> dimExprs;
+ for (auto en :
+ llvm::zip(sourceShape.slice(startDim, collapsedDims.size()),
+ sourceExprs.slice(startDim, collapsedDims.size()))) {
+ if (std::get<0>(en) == 1)
+ continue;
+ sizes.push_back(std::get<0>(en));
+ dimExprs.push_back(std::get<1>(en));
+ }
+ AffineExpr linearizedExpr =
+ makeCanonicalStridedLayoutExpr(sizes, dimExprs, context);
resultExprs.push_back(linearizedExpr);
}
return AffineMap::get(sourceMap.getNumDims(), sourceMap.getNumSymbols(),
@@ -349,6 +358,23 @@ static LinalgOp createLinalgOpOfSameType(LinalgOp op, PatternRewriter &rewriter,
return nullptr;
}
+/// Check if the reshape operation is only expansion into/collapsing of
+/// unit-dimension.
+static bool isUnitDimExpansionOnly(ArrayRef<int64_t> expandedShape,
+ ArrayRef<AffineMap> reassociation) {
+ for (auto &map : reassociation) {
+ unsigned numUnitDims = 0;
+ for (AffineExpr expr : map.getResults()) {
+ unsigned position = expr.cast<AffineDimExpr>().getPosition();
+ if (expandedShape[position] == 1)
+ numUnitDims++;
+ }
+ if (numUnitDims != map.getNumResults() - 1)
+ return false;
+ }
+ return true;
+}
+
/// Conditions for folding a generic/indexed-generic operation with a reshape op
/// by expanding the iteration space dimensionality for tensor operations. These
/// are preconditions assumed by `foldReshapeByDimExpansion` which implements
@@ -776,7 +802,7 @@ namespace {
/// %0 = linalg.generic { indexing_maps = [#map0, #map1, #map1] ... }
/// ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x4x?xf32>) ...
/// -> tensor<?x?x4x?xf32>
-template <typename LinalgOpTy>
+template <typename LinalgOpTy, bool foldUnitDimReshapesOnly>
struct FoldProducerReshapeOpByLinearization
: public OpRewritePattern<LinalgOpTy> {
using OpRewritePattern<LinalgOpTy>::OpRewritePattern;
@@ -792,7 +818,10 @@ struct FoldProducerReshapeOpByLinearization
if (!reshapeOp ||
!isTensorReshapeOpFoldableByLinearization(
reshapeOp, linalgOp.getInputIndexingMap(operand.index()),
- /*asProducer =*/true))
+ /*asProducer =*/true) ||
+ (foldUnitDimReshapesOnly &&
+ !isUnitDimExpansionOnly(reshapeOp.getResultType().getShape(),
+ reshapeOp.getReassociationMaps())))
continue;
// Compute the fused operands list,
@@ -858,7 +887,9 @@ struct FoldWithProducerReshapeOpByExpansion
// - All constraints of fusing with reshape by expansion are met.
if (reshapeOp.getSrcType().getRank() <
reshapeOp.getResultType().getRank() ||
- !isFusableWithReshapeByDimExpansion(linalgOp, operand.index()))
+ !isFusableWithReshapeByDimExpansion(linalgOp, operand.index()) ||
+ isUnitDimExpansionOnly(reshapeOp.getSrcType().getShape(),
+ reshapeOp.getReassociationMaps()))
continue;
Optional<SmallVector<Value, 1>> replacementValues =
@@ -877,6 +908,7 @@ struct FoldWithProducerReshapeOpByExpansion
/// Pattern to fold tensor_reshape op with its producer. The corresponding index
/// map in the consumer needs to be modified to linearize the folded dimension.
+template <bool foldUnitDimReshapesOnly>
struct FoldConsumerReshapeOpByLinearization
: public OpRewritePattern<TensorReshapeOp> {
using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
@@ -888,7 +920,11 @@ struct FoldConsumerReshapeOpByLinearization
!isa<GenericOp, IndexedGenericOp>(producer.getOperation()) ||
!producer.hasTensorSemantics() || producer.getNumOutputs() != 1 ||
!isTensorReshapeOpFoldableByLinearization(
- reshapeOp, producer.getOutputIndexingMap(0), /*asProducer =*/false))
+ reshapeOp, producer.getOutputIndexingMap(0),
+ /*asProducer =*/false) ||
+ (foldUnitDimReshapesOnly &&
+ !isUnitDimExpansionOnly(reshapeOp.getSrcType().getShape(),
+ reshapeOp.getReassociationMaps())))
return failure();
// The indexing_maps for the operands of the fused operation are same as
// those for the operands of the producer.
@@ -949,7 +985,10 @@ struct FoldReshapeWithGenericOpByExpansion
return failure();
LinalgOp producer = reshapeOp.src().getDefiningOp<LinalgOp>();
if (!producer || producer.getNumOutputs() != 1 ||
- !isFusableWithReshapeByDimExpansion(producer, producer.getNumInputs()))
+ !isFusableWithReshapeByDimExpansion(producer,
+ producer.getNumInputs()) ||
+ isUnitDimExpansionOnly(reshapeOp.getResultType().getShape(),
+ reshapeOp.getReassociationMaps()))
return failure();
Optional<SmallVector<Value, 1>> replacementValues =
fuseWithReshapeByExpansion(producer, reshapeOp, producer.getNumInputs(),
@@ -1098,9 +1137,16 @@ struct FoldReshapeOpsByLinearizationPass
void mlir::populateFoldReshapeOpsByLinearizationPatterns(
MLIRContext *context, OwningRewritePatternList &patterns) {
- patterns.insert<FoldProducerReshapeOpByLinearization<GenericOp>,
- FoldProducerReshapeOpByLinearization<IndexedGenericOp>,
- FoldConsumerReshapeOpByLinearization>(context);
+ patterns.insert<FoldProducerReshapeOpByLinearization<GenericOp, false>,
+ FoldProducerReshapeOpByLinearization<IndexedGenericOp, false>,
+ FoldConsumerReshapeOpByLinearization<false>>(context);
+}
+
+void mlir::populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
+ MLIRContext *context, OwningRewritePatternList &patterns) {
+ patterns.insert<FoldProducerReshapeOpByLinearization<GenericOp, true>,
+ FoldProducerReshapeOpByLinearization<IndexedGenericOp, true>,
+ FoldConsumerReshapeOpByLinearization<true>>(context);
}
void mlir::populateFoldReshapeOpsByExpansionPatterns(
diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
index 17b8bda967b1..d40a91667500 100644
--- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
+++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -331,3 +331,26 @@ func @fold_reshape(%arg0: tensor<2xf32>) -> tensor<2x1xf32>
] : tensor<2x1x1xf32> into tensor<2x1xf32>
return %1 : tensor<2x1xf32>
}
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+#map4 = affine_map<(d0, d1, d2) -> (d2)>
+func @fold_unit_dim_tensor_reshape_op(%arg0 : tensor<5xf32>) -> tensor<2x5xf32>
+{
+ %1 = linalg.init_tensor [1, 2, 5] : tensor<1x2x5xf32>
+ %2 = linalg.generic {i64, indexing_maps = [#map1, #map0],
+ iterator_types = ["parallel", "parallel", "parallel"]}
+ ins(%arg0 : tensor<5xf32>) outs(%1 : tensor<1x2x5xf32>) {
+ ^bb0(%arg1: f32, %arg2: f32): // no predecessors
+ linalg.yield %arg1 : f32
+ } -> tensor<1x2x5xf32>
+ %3 = linalg.tensor_reshape %2 [#map3, #map4]
+ : tensor<1x2x5xf32> into tensor<2x5xf32>
+ return %3 : tensor<2x5xf32>
+}
+// CHECK-LABEL: func @fold_unit_dim_tensor_reshape_op
+// CHECK: %[[RESULT:.+]] = linalg.generic
+// CHECK: return %[[RESULT]]
diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
index 447917548c5c..50269e36751b 100644
--- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -188,42 +188,6 @@ func @generic_op_reshape_consumer_static(%arg0: tensor<264x4xf32>)
// -----
-func @scalar_reshape(
- %arg0 : tensor<1x10xf32>, %arg1 : tensor<1xf32>) -> tensor<1x10xf32>
-{
- %0 = linalg.tensor_reshape %arg1 [] : tensor<1xf32> into tensor<f32>
- %1 = linalg.init_tensor [10] : tensor<10xf32>
- %2 = linalg.generic
- {indexing_maps = [affine_map<(d0) -> ()>, affine_map<(d0) -> (d0)>],
- iterator_types = ["parallel"]}
- ins(%0 : tensor<f32>)
- outs(%1 : tensor<10xf32>) {
- ^bb0(%arg2: f32, %s: f32): // no predecessors
- linalg.yield %arg2 : f32
- } -> tensor<10xf32>
- %3 = linalg.tensor_reshape %2 [affine_map<(d0, d1) -> (d0, d1)>]
- : tensor<10xf32> into tensor<1x10xf32>
- return %3 : tensor<1x10xf32>
-}
-
-// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0, d1)>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> ()>
-// CHECK: func @scalar_reshape
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<1x10xf32>
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<1xf32>
-// CHECK: %[[T0:.+]] = linalg.tensor_reshape %[[ARG1]] []
-// CHECK-SAME: tensor<1xf32> into tensor<f32>
-// CHECK: %[[T1:.+]] = linalg.init_tensor [10]
-// CHECK: %[[T2:.+]] = linalg.tensor_reshape %[[T1]] [#[[MAP0]]]
-// CHECK: %[[T3:.+]] = linalg.generic
-// CHECK-SAME: indexing_maps = [#[[MAP1]], #[[MAP0]]]
-// CHECK-SAME: iterator_types = ["parallel", "parallel"]
-// CHECK-SAME: ins(%[[T0]] : tensor<f32>)
-// CHECK-SAME: outs(%[[T2]] : tensor<1x10xf32>)
-// CHECK: return %[[T3]] : tensor<1x10xf32>
-
-// -----
-
#map0 = affine_map<(d0, d1, d2) -> (d2, d0, d1)>
#map1 = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
func @indexed_generic_op_reshape_producer_fusion(%arg0 : tensor<?x?x4x?xi32>,
@@ -336,7 +300,7 @@ func @reshape_as_consumer_permutation
%5 = addi %3, %4 : i32
%6 = index_cast %arg2 : index to i32
%7 = addi %5, %6 : i32
- linalg.yield %7 : i32
+ linalg.yield %7 : i32
} -> tensor<6x4x210xi32>
%d = linalg.tensor_reshape %c
[affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1)>,
@@ -493,3 +457,77 @@ func @generic_op_reshape_consumer_fusion_projected(%arg0 : tensor<?x?xf32>,
// CHECK-SAME: ins(%[[T0]], %[[T1]] : tensor<?x4x5x?xf32>, tensor<?x4x5x?xf32>)
// CHECK-SAME: outs(%[[T2]] : tensor<?x?x4x5xf32>)
// CHECK: return %[[T3]] : tensor<?x?x4x5xf32>
+
+// -----
+
+func @unit_dim_reshape_expansion(%arg0 : tensor<1x5xf32>) -> tensor<5x5xf32> {
+ %0 = linalg.tensor_reshape %arg0
+ [affine_map<(d0, d1) -> (d0, d1)>] : tensor<1x5xf32> into tensor<5xf32>
+ %1 = linalg.init_tensor [5, 5] : tensor<5x5xf32>
+ %2 = linalg.generic
+ {indexing_maps = [affine_map<(d0, d1) -> (d0)>,
+ affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%0 : tensor<5xf32>) outs(%1 : tensor<5x5xf32>) {
+ ^bb0(%arg2: f32, %arg3: f32): // no predecessors
+ linalg.yield %arg2 : f32
+ } -> tensor<5x5xf32>
+ return %2 : tensor<5x5xf32>
+}
+// CHECK: func @unit_dim_reshape_expansion
+// CHECK-DAG: linalg.tensor_reshape
+// CHECK-DAG: linalg.init_tensor
+// CHECK: linalg.generic
+
+// -----
+
+func @unit_dim_reshape_collapse(%arg0 : tensor<5xf32>) -> tensor<5x1x5xf32> {
+ %0 = linalg.init_tensor [5, 5] : tensor<5x5xf32>
+ %1 = linalg.generic
+ {indexing_maps = [affine_map<(d0, d1) -> (d0)>,
+ affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%arg0 : tensor<5xf32>) outs(%0 : tensor<5x5xf32>) {
+ ^bb0(%arg2: f32, %arg3: f32): // no predecessors
+ linalg.yield %arg2 : f32
+ } -> tensor<5x5xf32>
+ %2 = linalg.tensor_reshape %1
+ [affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>]
+ : tensor<5x5xf32> into tensor<5x1x5xf32>
+ return %2 : tensor<5x1x5xf32>
+}
+// CHECK: func @unit_dim_reshape_collapse
+// CHECK: linalg.init_tensor
+// CHECK: linalg.generic
+// CHECK: linalg.tensor_reshape
+
+// -----
+
+func @unit_dim_reshape_expansion_full
+ (%arg0 : tensor<1x?x1x2x1x4xf32>, %arg1 : tensor<?x2x4xf32>)
+ -> tensor<?x2x4xf32> {
+ %c1 = constant 1 : index
+ %0 = linalg.tensor_reshape %arg0
+ [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2)>,
+ affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4)>,
+ affine_map<(d0, d1, d2, d3, d4, d5) -> (d5)>]
+ : tensor<1x?x1x2x1x4xf32> into tensor<?x2x4xf32>
+ %1 = dim %arg0, %c1 : tensor<1x?x1x2x1x4xf32>
+ %2 = linalg.init_tensor [%1, 2, 4] : tensor<?x2x4xf32>
+ %3 = linalg.generic
+ {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
+ iterator_types = ["parallel", "parallel", "parallel"]}
+ ins(%0, %arg1 : tensor<?x2x4xf32>, tensor<?x2x4xf32>)
+ outs(%2 : tensor<?x2x4xf32>) {
+ ^bb0(%arg2: f32, %arg3: f32, %arg4: f32): // no predecessors
+ %4 = mulf %arg2, %arg3 : f32
+ linalg.yield %4 : f32
+ } -> tensor<?x2x4xf32>
+ return %3 : tensor<?x2x4xf32>
+}
+// CHECK: func @unit_dim_reshape_expansion_full
+// CHECK-DAG: linalg.tensor_reshape
+// CHECK-DAG: linalg.init_tensor
+// CHECK: linalg.generic
More information about the llvm-branch-commits mailing list