[Mlir-commits] [mlir] f12639d - [mlir][Linalg] Avoid collapsing dimensions of linalg op that aren't foldable.
Mahesh Ravishankar
llvmlistbot at llvm.org
Tue May 9 22:45:18 PDT 2023
Author: Mahesh Ravishankar
Date: 2023-05-10T05:45:02Z
New Revision: f12639d0d674327876388cfde6b4d226359284ac
URL: https://github.com/llvm/llvm-project/commit/f12639d0d674327876388cfde6b4d226359284ac
DIFF: https://github.com/llvm/llvm-project/commit/f12639d0d674327876388cfde6b4d226359284ac.diff
LOG: [mlir][Linalg] Avoid collapsing dimensions of linalg op that aren't foldable.
The collapsing dimensions transformation is limited to only those
cases where each sequence of dimensions is contiguous in all the
ranges of the indexing maps of the operation. Add this check before
applying the transformation.
Differential Revision: https://reviews.llvm.org/D150176
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
mlir/test/Dialect/Linalg/collapse-dim.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 4a3a86ff6a326..b538d0f0bf7a4 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -901,8 +901,21 @@ splitReductionByScaling(RewriterBase &b, LinalgOp op,
const ControlSplitReductionFn &controlSplitReductionFn,
bool useAlloc = false);
-/// Collapses dimensions of linalg.generic operation. It also collapses inputs
-/// before the op and expands outputs after the op.
+/// Return `true` if a given sequence of dimensions is contiguous in the
+/// range of the specified indexing map.
+bool isDimSequencePreserved(AffineMap map, ReassociationIndicesRef dimSequence);
+/// Return `true` if all sequences of dimensions specified in `dimSequences` are
+/// contiguous in all the ranges of the `maps`.
+bool areDimSequencesPreserved(ArrayRef<AffineMap> maps,
+ ArrayRef<ReassociationIndices> dimSequences);
+
+/// Collapses dimensions of linalg.generic operation. A precondition to
+/// calling this method is that for each list in `foldedIterationDims`, the
+/// sequence of dimensions is contiguous in ranges of all `indexing_maps` of
+/// the `genericOp`. This can be checked using `areDimSequencesPreserved` method.
+/// When valid, the method also collapses the operands of the op. Returns
+/// replacement values of the results of the original `genericOp` by inserting
+/// reshapes to get back values of compatible types.
FailureOr<SmallVector<Value>> collapseGenericOpIterationDims(
GenericOp genericOp, ArrayRef<ReassociationIndices> foldedIterationDims,
RewriterBase &rewriter);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 57e6e2a6c81e4..bf728a6ec319b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1004,8 +1004,8 @@ getDomainReassociation(AffineMap indexingMap,
/// For a given `dimSequence`, check if the sequence is conserved in the
/// `indexingMap`. `indexingMap` is expected to be a projected permutation.
/// Non-existence of the sequence returns true as well.
-static bool isDimSequencePreserved(AffineMap indexingMap,
- ReassociationIndicesRef dimSequence) {
+bool mlir::linalg::isDimSequencePreserved(AffineMap indexingMap,
+ ReassociationIndicesRef dimSequence) {
assert(!dimSequence.empty() &&
"expected non-empty list for dimension sequence");
assert(indexingMap.isProjectedPermutation() &&
@@ -1045,6 +1045,15 @@ static bool isDimSequencePreserved(AffineMap indexingMap,
return true;
}
+bool mlir::linalg::areDimSequencesPreserved(
+ ArrayRef<AffineMap> maps, ArrayRef<ReassociationIndices> dimSequences) {
+ return llvm::all_of(maps, [&](AffineMap map) {
+ return llvm::all_of(dimSequences, [&](ReassociationIndicesRef dimSequence) {
+ return isDimSequencePreserved(map, dimSequence);
+ });
+ });
+}
+
// Return the list of dimensions of the iteration domain that can be
// collapsed to allow for fusion with a producer that is an expand_shape
// operation. If all dimensions created by expansion can be collapsed in the
@@ -1592,6 +1601,13 @@ class CollapseLinalgDimensions : public OpRewritePattern<GenericOp> {
if (collapsableIterationDims.empty())
return failure();
+ // Check if the specified list of dimensions to collapse is a valid list.
+ if (!areDimSequencesPreserved(genericOp.getIndexingMapsArray(),
+ collapsableIterationDims)) {
+ return rewriter.notifyMatchFailure(
+ genericOp, "specified dimensions cannot be collapsed");
+ }
+
std::optional<SmallVector<Value>> replacements =
collapseGenericOpIterationDims(genericOp, collapsableIterationDims,
rewriter);
diff --git a/mlir/test/Dialect/Linalg/collapse-dim.mlir b/mlir/test/Dialect/Linalg/collapse-dim.mlir
index 0bd2bc1a99558..6737a6e15da5a 100644
--- a/mlir/test/Dialect/Linalg/collapse-dim.mlir
+++ b/mlir/test/Dialect/Linalg/collapse-dim.mlir
@@ -53,3 +53,20 @@ func.func @collapse_parallel(
// CHECK-SAME: ins(%[[S]] : tensor<32x2x40960xf32>) outs(%[[D]] : tensor<2x32x40960xf32>) {
// CHECK: } -> tensor<2x32x40960xf32>
// CHECK: tensor.expand_shape %[[R]] {{\[}}[0], [1], [2, 3]] : tensor<2x32x40960xf32> into tensor<2x32x10x4096xf32>
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3) -> (d3, d0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+func.func @uncollapsable(%arg0 : tensor<41x3x1x57xf32>, %arg1 : tensor<3x1x57x41xf32>) -> tensor<3x1x57x41xf32> {
+ %0 = linalg.generic {
+ indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+ ins(%arg0 : tensor<41x3x1x57xf32>) outs(%arg1 : tensor<3x1x57x41xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ linalg.yield %in : f32
+ } -> tensor<3x1x57x41xf32>
+ return %0 : tensor<3x1x57x41xf32>
+}
+// CHECK-LABEL: func @uncollapsable(
+// CHECK: linalg.generic
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]
More information about the Mlir-commits
mailing list