[Mlir-commits] [mlir] cdf7b66 - [mlir][Linalg] Fix incorrect logic in deciding when to fuse reshapes by linearization.
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Jul 2 11:17:02 PDT 2021
Author: MaheshRavishankar
Date: 2021-07-02T11:16:21-07:00
New Revision: cdf7b661c24d037461492544996925dd5257911b
URL: https://github.com/llvm/llvm-project/commit/cdf7b661c24d037461492544996925dd5257911b
DIFF: https://github.com/llvm/llvm-project/commit/cdf7b661c24d037461492544996925dd5257911b.diff
LOG: [mlir][Linalg] Fix incorrect logic in deciding when to fuse reshapes by linearization.
Fusion by linearization should not happen when
- The reshape is expanding and it is a consumer
- The reshape is collapsing and is a producer.
The bug introduced in this logic by some recent refactoring resulted
in a crash.
To enforce this (negative) use case, add a test that reproduces the
error and verifies the fix.
Differential Revision: https://reviews.llvm.org/D104970
Added:
mlir/test/Dialect/Linalg/reshape_linearization_fusion_with_unit_dims.mlir
Modified:
mlir/include/mlir/Dialect/Linalg/Passes.td
mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index c638294b1210..fc669b270576 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -75,6 +75,12 @@ def LinalgFoldReshapeOpsByLinearization :
let summary = "Fold TensorReshapeOps with generic/indexed generic ops by "
"linearization";
let constructor = "mlir::createFoldReshapeOpsByLinearizationPass()";
+ let options = [
+ Option<"allowFoldingUnitDimReshapes", "allow-folding-unit-dim-reshapes",
+ "bool", /*default=*/"false",
+ "Allow fusing linalg.tensor_reshape ops that performs unit "
+ "dimension collapsing">
+ ];
let dependentDialects = ["AffineDialect", "memref::MemRefDialect"];
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
index f577f00c8b0f..20f13bd45107 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -320,27 +320,27 @@ fuseElementwiseOpsImpl(GenericOp producer, OpOperand *consumerOpOperand,
/// %0 = op ... : tensor<?x?x4x5xf32>
/// with output index_map
/// `affine_map<(d0, d1, d2, d3) -> (d0, d1 * 20 + d2 * 5 + d3)>`
+template <typename TensorReshapeOp>
static AffineMap linearizeCollapsedDims(AffineMap sourceMap,
- ArrayRef<int64_t> sourceShape,
- ArrayRef<AffineMap> reassociationMaps) {
+ TensorReshapeOp reshapeOp) {
+ constexpr bool isExpanding =
+ std::is_same<TensorReshapeOp, TensorExpandShapeOp>::value;
+ ArrayRef<int64_t> sourceShape =
+ (isExpanding ? reshapeOp.getResultType().getShape()
+ : reshapeOp.getSrcType().getShape());
SmallVector<AffineExpr> resultExprs;
- resultExprs.reserve(reassociationMaps.size());
ArrayRef<AffineExpr> sourceExprs = sourceMap.getResults();
MLIRContext *context = sourceMap.getContext();
// Compute the result exprs based on the reassociation maps.
- for (AffineMap map : reassociationMaps) {
- ArrayRef<AffineExpr> collapsedDims = map.getResults();
+ for (auto &indices : reshapeOp.getReassociationIndices()) {
// Assume that they are in-order and contiguous (already checked in
// verifier).
- assert(!collapsedDims.empty());
- unsigned startDim =
- collapsedDims.front().cast<AffineDimExpr>().getPosition();
+ assert(!indices.empty());
SmallVector<int64_t> sizes;
SmallVector<AffineExpr> dimExprs;
- for (auto en :
- llvm::zip(sourceShape.slice(startDim, collapsedDims.size()),
- sourceExprs.slice(startDim, collapsedDims.size()))) {
+ for (auto en : llvm::zip(sourceShape.slice(indices[0], indices.size()),
+ sourceExprs.slice(indices[0], indices.size()))) {
if (std::get<0>(en) == 1)
continue;
sizes.push_back(std::get<0>(en));
@@ -359,7 +359,7 @@ static AffineMap linearizeCollapsedDims(AffineMap sourceMap,
// divs in the indexing maps of the fused op which would make it non-invertible.
static bool isTensorReshapeOpFoldableByLinearization(
TensorExpandShapeOp expandOp, AffineMap useIndexMap, bool asProducer) {
- if (!asProducer && expandOp.getResultType().hasStaticShape())
+ if (!asProducer)
return false;
return useIndexMap.isPermutation();
}
@@ -368,23 +368,26 @@ static bool isTensorReshapeOpFoldableByLinearization(
// consumer).
static bool isTensorReshapeOpFoldableByLinearization(
TensorCollapseShapeOp collapseOp, AffineMap useIndexMap, bool asProducer) {
- if (asProducer && collapseOp.getSrcType().hasStaticShape())
+ if (asProducer)
return false;
return useIndexMap.isPermutation();
}
/// Check if the reshape operation is only expansion into/collapsing of
/// unit-dimension.
-static bool isUnitDimExpansionOnly(ArrayRef<int64_t> expandedShape,
- ArrayRef<AffineMap> reassociation) {
- for (auto &map : reassociation) {
+template <typename TensorReshapeOp>
+static bool isUnitDimExpansionOnly(TensorReshapeOp reshapeOp) {
+ constexpr bool isExpanding =
+ std::is_same<TensorReshapeOp, TensorExpandShapeOp>::value;
+ ArrayRef<int64_t> expandedShape =
+ (isExpanding ? reshapeOp.getResultType().getShape()
+ : reshapeOp.getSrcType().getShape());
+ for (auto &indices : reshapeOp.getReassociationIndices()) {
unsigned numUnitDims = 0;
- for (AffineExpr expr : map.getResults()) {
- unsigned position = expr.cast<AffineDimExpr>().getPosition();
+ for (int64_t position : indices)
if (expandedShape[position] == 1)
numUnitDims++;
- }
- if (numUnitDims != map.getNumResults() - 1)
+ if (numUnitDims != indices.size() - 1)
return false;
}
return true;
@@ -818,14 +821,10 @@ struct FoldProducerReshapeOpByLinearization
if (!reshapeOp)
continue;
- RankedTensorType returnType = reshapeOp.getResultType();
-
if (!isTensorReshapeOpFoldableByLinearization(
reshapeOp, genericOp.getTiedIndexingMap(en.value()),
/*asProducer =*/true) ||
- (foldUnitDimReshapesOnly &&
- !isUnitDimExpansionOnly(returnType.getShape(),
- reshapeOp.getReassociationMaps())))
+ (foldUnitDimReshapesOnly && !isUnitDimExpansionOnly(reshapeOp)))
continue;
// Compute the fused operands list,
@@ -842,8 +841,10 @@ struct FoldProducerReshapeOpByLinearization
auto invMap = inversePermutation(fusedIndexMaps[en.index()]);
// Compute the indexing map to use for the result of the producer.
- AffineMap modifiedMap = linearizeCollapsedDims(
- invMap, returnType.getShape(), reshapeOp.getReassociationMaps());
+ AffineMap modifiedMap = linearizeCollapsedDims(invMap, reshapeOp);
+ // The modified map cannot have symbols.
+ if (modifiedMap.getNumSymbols())
+ return failure();
for (AffineExpr expr : modifiedMap.getResults()) {
if (!expr.isPureAffine())
return failure();
@@ -1081,9 +1082,7 @@ struct FoldConsumerReshapeOpByLinearization
reshapeOp,
producer.getTiedIndexingMap(producer.getOutputOperand(0)),
/*asProducer =*/false) ||
- (foldUnitDimReshapesOnly &&
- !isUnitDimExpansionOnly(reshapeOp.getSrcType().getShape(),
- reshapeOp.getReassociationMaps())))
+ (foldUnitDimReshapesOnly && !isUnitDimExpansionOnly(reshapeOp)))
return failure();
// The indexing_maps for the operands of the fused operation are same as
// those for the operands of the producer.
@@ -1093,9 +1092,7 @@ struct FoldConsumerReshapeOpByLinearization
producer.getTiedIndexingMap(producer.getOutputOperand(0)));
// Compute the indexing map to use for the operand of the producer.
- AffineMap modifiedMap =
- linearizeCollapsedDims(invMap, reshapeOp.getSrcType().getShape(),
- reshapeOp.getReassociationMaps());
+ AffineMap modifiedMap = linearizeCollapsedDims(invMap, reshapeOp);
for (AffineExpr expr : modifiedMap.getResults()) {
if (!expr.isPureAffine()) {
return rewriter.notifyMatchFailure(
@@ -1144,8 +1141,7 @@ struct FoldReshapeWithGenericOpByExpansion
if (!producer || producer.getNumOutputs() != 1 ||
!isFusableWithReshapeByDimExpansion(producer,
producer.getOutputOperand(0)) ||
- isUnitDimExpansionOnly(reshapeOp.getResultType().getShape(),
- reshapeOp.getReassociationMaps()))
+ isUnitDimExpansionOnly(reshapeOp))
return failure();
Optional<SmallVector<Value>> replacementValues = fuseWithReshapeByExpansion(
producer, reshapeOp, producer.getOutputOperand(0), rewriter);
@@ -1248,12 +1244,10 @@ bool mlir::linalg::skipUnitDimReshape(const OpResult &producer,
const OpOperand &consumer) {
auto expandShapeOp = producer.getDefiningOp<linalg::TensorExpandShapeOp>();
if (expandShapeOp)
- return !isUnitDimExpansionOnly(expandShapeOp.getSrcType().getShape(),
- expandShapeOp.getReassociationMaps());
+ return !isUnitDimExpansionOnly(expandShapeOp);
auto collapseShapeOp =
producer.getDefiningOp<linalg::TensorCollapseShapeOp>();
- return !isUnitDimExpansionOnly(collapseShapeOp.getSrcType().getShape(),
- collapseShapeOp.getReassociationMaps());
+ return !isUnitDimExpansionOnly(collapseShapeOp);
}
namespace {
@@ -1312,6 +1306,9 @@ struct FoldReshapeOpsByLinearizationPass
Operation *op = getOperation();
RewritePatternSet patterns(op->getContext());
populateFoldReshapeOpsByLinearizationPatterns(patterns);
+ if (allowFoldingUnitDimReshapes) {
+ populateFoldUnitDimsReshapeOpsByLinearizationPatterns(patterns);
+ }
(void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns));
}
};
diff --git a/mlir/test/Dialect/Linalg/reshape_linearization_fusion_with_unit_dims.mlir b/mlir/test/Dialect/Linalg/reshape_linearization_fusion_with_unit_dims.mlir
new file mode 100644
index 000000000000..a4a27b5e747b
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/reshape_linearization_fusion_with_unit_dims.mlir
@@ -0,0 +1,52 @@
+// RUN: mlir-opt -linalg-fold-reshape-ops-by-linearization=allow-folding-unit-dim-reshapes -split-input-file %s | FileCheck %s
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+func @do_not_fold1(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?x1xf32>
+{
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+ %1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
+ %2 = linalg.init_tensor [%0, %1] : tensor<?x?xf32>
+ %3 = linalg.generic {
+ indexing_maps = [#map, #map, #map],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%2 : tensor<?x?xf32>) {
+ ^bb0(%arg2 : f32, %arg3 : f32, %arg4 : f32):
+ %4 = addf %arg2, %arg3 : f32
+ linalg.yield %4 : f32
+ } -> tensor<?x?xf32>
+ %4 = linalg.tensor_expand_shape %3 [[0], [1, 2]] : tensor<?x?xf32> into tensor<?x?x1xf32>
+ return %4 : tensor<?x?x1xf32>
+}
+// CHECK-LABEL: func @do_not_fold1
+// CHECK: %[[VAL:.+]] = linalg.generic
+// CHECK: linalg.tensor_expand_shape %[[VAL]]
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+func @do_not_fold2(%arg0 : tensor<?x?x1xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32>
+{
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %0 = linalg.tensor_collapse_shape %arg0 [[0], [1, 2]] : tensor<?x?x1xf32> into tensor<?x?xf32>
+ %1 = tensor.dim %arg1, %c0 : tensor<?x?xf32>
+ %2 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
+ %3 = linalg.init_tensor [%1, %2] : tensor<?x?xf32>
+ %4 = linalg.generic {
+ indexing_maps = [#map, #map, #map],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%3 : tensor<?x?xf32>) {
+ ^bb0(%arg2 : f32, %arg3 : f32, %arg4 : f32):
+ %4 = addf %arg2, %arg3 : f32
+ linalg.yield %4 : f32
+ } -> tensor<?x?xf32>
+ return %4 : tensor<?x?xf32>
+}
+// CHECK-LABEL: func @do_not_fold2
+// CHECK: %[[VAL:.+]] = linalg.tensor_collapse_shape
+// CHECK: linalg.generic
+// CHECK-SAME: ins(%[[VAL]], %{{.+}} : tensor<?x?xf32>, tensor<?x?xf32>)
More information about the Mlir-commits
mailing list