[Mlir-commits] [mlir] ff5de8a - [linalg][fusion] Disallow fusion when it would create an invalid expand_shape
Benjamin Kramer
llvmlistbot at llvm.org
Tue Jan 18 14:44:40 PST 2022
Author: Benjamin Kramer
Date: 2022-01-18T23:44:14+01:00
New Revision: ff5de8a9e0e5cb7f82b945486b784407f6aab8fe
URL: https://github.com/llvm/llvm-project/commit/ff5de8a9e0e5cb7f82b945486b784407f6aab8fe
DIFF: https://github.com/llvm/llvm-project/commit/ff5de8a9e0e5cb7f82b945486b784407f6aab8fe.diff
LOG: [linalg][fusion] Disallow fusion when it would create an invalid expand_shape
The input type of a linalg.generic can be less dynamic than its output
type. If this is the case, moving a reshape across the generic op would
create invalid IR, as expand_shape cannot expand arbitrary dynamic
dimensions.
Check that the reshape is actually valid before creating the
expand_shape. This change exposes the existing verification logic in
reshape utils and removes the incomplete custom implementation in fusion.
Differential Revision: https://reviews.llvm.org/D116600
Added:
Modified:
mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
mlir/test/Dialect/Linalg/reshape_fusion.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index 1c42c4b8542e..b2d4cf1e4bff 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -166,47 +166,19 @@ static LogicalResult verifyReshapeLikeTypes(Op op, T expandedType,
/// 2) if a dimension in the collaped type is dynamic, one and only one of the
/// corresponding dimensions in the expanded type should be dynamic. This
/// rule is only needed with reshape operations that are expanding.
+LogicalResult reshapeLikeShapesAreCompatible(
+ function_ref<LogicalResult(const Twine &)> emitError,
+ ArrayRef<int64_t> collapsedShape, ArrayRef<int64_t> expandedShape,
+ ArrayRef<ReassociationIndices> reassociationMaps, bool isExpandingReshape);
+
template <typename OpTy>
static LogicalResult verifyReshapeLikeShapes(OpTy op, ShapedType collapsedType,
ShapedType expandedType,
bool isExpandingReshape) {
- ArrayRef<int64_t> collapsedShape = collapsedType.getShape();
- ArrayRef<int64_t> expandedShape = expandedType.getShape();
- unsigned expandedDimStart = 0;
- for (auto map : llvm::enumerate(op.getReassociationMaps())) {
- Optional<int64_t> dynamicShape;
- int64_t linearizedStaticShape = 1;
- for (auto dim : llvm::enumerate(expandedShape.slice(
- expandedDimStart, map.value().getNumResults()))) {
- if (ShapedType::isDynamic(dim.value())) {
- if (isExpandingReshape && dynamicShape) {
- return op->emitOpError("invalid to have a single dimension (")
- << map.index() << ") expanded into multiple dynamic dims ("
- << expandedDimStart + dynamicShape.getValue() << ","
- << expandedDimStart + dim.index() << ")";
- }
- dynamicShape = dim.index();
- } else {
- linearizedStaticShape *= dim.value();
- }
- }
- if (dynamicShape) {
- if (!ShapedType::isDynamic(collapsedShape[map.index()])) {
- return op->emitOpError("expected dimension ")
- << map.index()
- << " of collapsed type to be dynamic since one or more of the "
- "corresponding dimensions in the expanded type is dynamic";
- }
- } else {
- if (collapsedShape[map.index()] != linearizedStaticShape) {
- return op->emitOpError("expected dimension ")
- << map.index() << " of collapsed type to be static value of "
- << linearizedStaticShape << " ";
- }
- }
- expandedDimStart += map.value().getNumResults();
- }
- return success();
+ return reshapeLikeShapesAreCompatible(
+ [&](const Twine &msg) { return op->emitOpError(msg); },
+ collapsedType.getShape(), expandedType.getShape(),
+ op.getReassociationIndices(), isExpandingReshape);
}
/// Pattern to collapse producer/consumer reshape ops that are both collapsing
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index ce636f3d84d3..619908ca5a38 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -608,31 +608,6 @@ LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
LogicalResult isGenericOpExpandable(GenericOp genericOp,
const ExpansionInfo &expansionInfo,
PatternRewriter &rewriter) {
- // Current reshape only supports expansion of a dynamic dim when only one of
- // the expanded dims are dynamic.
- for (const auto &originalShape :
- llvm::enumerate(expansionInfo.getOriginalShape()))
- if (ShapedType::isDynamic(originalShape.value())) {
- // All but one of the expanded dims must be static.
- bool foundDynamicExpandedDim = false;
- for (auto expandedShape :
- expansionInfo.getExpandedShapeOfDim(originalShape.index())) {
- if (ShapedType::isDynamic(expandedShape)) {
- if (foundDynamicExpandedDim) {
- return rewriter.notifyMatchFailure(
- genericOp,
- "cannot expanded dynamic dims into multiple dynamic dims");
- }
- foundDynamicExpandedDim = true;
- }
- }
- if (!foundDynamicExpandedDim) {
- return rewriter.notifyMatchFailure(
- genericOp, "dynamic dim expansion needs at least one dynamic dim "
- "in result shape");
- }
- }
-
if (!genericOp.hasIndexSemantics())
return success();
for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
@@ -793,13 +768,21 @@ fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
}
if (genericOp.isInputTensor(opOperand)) {
AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
+ auto opOperandType = opOperand->get().getType().cast<RankedTensorType>();
RankedTensorType expandedOperandType =
- getExpandedType(opOperand->get().getType().cast<RankedTensorType>(),
- indexingMap, expansionInfo);
+ getExpandedType(opOperandType, indexingMap, expansionInfo);
if (expandedOperandType != opOperand->get().getType()) {
// Reshape the operand to get the right type.
SmallVector<ReassociationIndices> reassociation =
getReassociationForExpansion(indexingMap, expansionInfo);
+ if (failed(reshapeLikeShapesAreCompatible(
+ [&](const Twine &msg) {
+ return rewriter.notifyMatchFailure(genericOp, msg);
+ },
+ opOperandType.getShape(), expandedOperandType.getShape(),
+ reassociation,
+ /*isExpandingReshape=*/true)))
+ return llvm::None;
expandedOpOperands.push_back(rewriter.create<tensor::ExpandShapeOp>(
genericOp.getLoc(), expandedOperandType, opOperand->get(),
reassociation));
@@ -813,12 +796,20 @@ fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
SmallVector<Value> outputs;
for (OpOperand *opOperand : genericOp.getOutputOperands()) {
AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
+ auto opOperandType = opOperand->get().getType().cast<RankedTensorType>();
RankedTensorType expandedOutputType =
- getExpandedType(opOperand->get().getType().cast<RankedTensorType>(),
- indexingMap, expansionInfo);
+ getExpandedType(opOperandType, indexingMap, expansionInfo);
if (expandedOutputType != opOperand->get().getType()) {
SmallVector<ReassociationIndices> reassociation =
getReassociationForExpansion(indexingMap, expansionInfo);
+ if (failed(reshapeLikeShapesAreCompatible(
+ [&](const Twine &msg) {
+ return rewriter.notifyMatchFailure(genericOp, msg);
+ },
+ opOperandType.getShape(), expandedOutputType.getShape(),
+ reassociation,
+ /*isExpandingReshape=*/true)))
+ return llvm::None;
outputs.push_back(rewriter.create<tensor::ExpandShapeOp>(
genericOp.getLoc(), expandedOutputType, opOperand->get(),
reassociation));
diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
index b5136273d635..0048abee4194 100644
--- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
@@ -276,3 +276,45 @@ bool mlir::isReassociationValid(ArrayRef<AffineMap> reassociation,
}
return true;
}
+
+LogicalResult mlir::reshapeLikeShapesAreCompatible(
+ function_ref<LogicalResult(const Twine &)> emitError,
+ ArrayRef<int64_t> collapsedShape, ArrayRef<int64_t> expandedShape,
+ ArrayRef<ReassociationIndices> reassociationMaps, bool isExpandingReshape) {
+ unsigned expandedDimStart = 0;
+ for (const auto &map : llvm::enumerate(reassociationMaps)) {
+ Optional<int64_t> dynamicShape;
+ int64_t linearizedStaticShape = 1;
+ for (const auto &dim : llvm::enumerate(
+ expandedShape.slice(expandedDimStart, map.value().size()))) {
+ if (ShapedType::isDynamic(dim.value())) {
+ if (isExpandingReshape && dynamicShape) {
+ return emitError("invalid to have a single dimension (" +
+ Twine(map.index()) +
+ ") expanded into multiple dynamic dims (" +
+ Twine(expandedDimStart + dynamicShape.getValue()) +
+ "," + Twine(expandedDimStart + dim.index()) + ")");
+ }
+ dynamicShape = dim.index();
+ } else {
+ linearizedStaticShape *= dim.value();
+ }
+ }
+ if (dynamicShape) {
+ if (!ShapedType::isDynamic(collapsedShape[map.index()])) {
+ return emitError(
+ "expected dimension " + Twine(map.index()) +
+ " of collapsed type to be dynamic since one or more of the "
+ "corresponding dimensions in the expanded type is dynamic");
+ }
+ } else {
+ if (collapsedShape[map.index()] != linearizedStaticShape) {
+ return emitError("expected dimension " + Twine(map.index()) +
+ " of collapsed type to be static value of " +
+ Twine(linearizedStaticShape));
+ }
+ }
+ expandedDimStart += map.value().size();
+ }
+ return success();
+}
diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
index 9582a4bbafc4..508726a4e60e 100644
--- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -530,3 +530,30 @@ func @no_fuse_dynamic_dims(%arg0: tensor<?x?xf32>) -> tensor<?xf32> {
// CHECK: %[[GENERIC:.+]] = linalg.generic
// CHECK-SAME: ins(%[[RESHAPE]] : tensor<?xf32>)
// CHECK: return %[[GENERIC]]
+
+// -----
+
+func @no_fuse_mismatched_dynamism(%arg0: tensor<1x1xi64>, %arg1: tensor<?xi64>) -> tensor<1xi64> {
+ %0 = tensor.collapse_shape %arg0 [[0, 1]] : tensor<1x1xi64> into tensor<1xi64>
+ %1 = linalg.init_tensor [1] : tensor<1xi64>
+ %2 = linalg.generic
+ {indexing_maps = [affine_map<(d0) -> (d0)>,
+ affine_map<(d0) -> (d0)>,
+ affine_map<(d0) -> (d0)>],
+ iterator_types = ["parallel"]}
+ ins(%0, %arg1 : tensor<1xi64>, tensor<?xi64>)
+ outs(%1 : tensor<1xi64>) {
+ ^bb0(%arg4: i64, %arg5: i64, %arg6: i64): // no predecessors
+ %3 = arith.addi %arg4, %arg5 : i64
+ linalg.yield %3 : i64
+ } -> tensor<1xi64>
+ return %2 : tensor<1xi64>
+}
+
+// CHECK: func @no_fuse_mismatched_dynamism
+// CHECK-SAME: %[[ARG0:.+]]: tensor<1x1xi64>
+// CHECK-SAME: %[[ARG1:.+]]: tensor<?xi64>
+// CHECK: %[[RESHAPE:.+]] = tensor.collapse_shape %[[ARG0]]
+// CHECK: %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[RESHAPE]], %[[ARG1]] : tensor<1xi64>, tensor<?xi64>)
+// CHECK: return %[[GENERIC]]
More information about the Mlir-commits
mailing list