[Mlir-commits] [mlir] b40e901 - [mlir][Linalg] Allow collapsing a subset of the reassociations when fusing by collapsing.
Mahesh Ravishankar
llvmlistbot at llvm.org
Tue Apr 12 11:56:42 PDT 2022
Author: Mahesh Ravishankar
Date: 2022-04-12T18:56:32Z
New Revision: b40e901333b903fd71f17c3314d3e40f8abde074
URL: https://github.com/llvm/llvm-project/commit/b40e901333b903fd71f17c3314d3e40f8abde074
DIFF: https://github.com/llvm/llvm-project/commit/b40e901333b903fd71f17c3314d3e40f8abde074.diff
LOG: [mlir][Linalg] Allow collapsing a subset of the reassociations when fusing by collapsing.
This change generalizes the fusion of `tensor.expand_shape` ->
`linalg.generic` op by collapsing to handle cases where only a subset
of the reassociations specified in the `tensor.expand_shape` can be
collapsed.
The method that performs the collapsing is refactored so that it can
serve as a generic utility when required.
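For illustration, a minimal sketch of the new behavior (shapes and maps
borrowed from the `fuse_only_one_reassociation` test added below; the IR is
abbreviated):

```mlir
// #map0 transposes d0 and d1, so the [0, 1] group of the expansion cannot
// be collapsed without introducing mods/divs; the [2, 3] group can.
#map0 = affine_map<(d0, d1, d2, d3) -> (d1, d0, d2, d3)>
%0 = tensor.expand_shape %arg0 [[0, 1], [2, 3]]
    : tensor<?x?xf32> into tensor<?x4x?x8xf32>
// Previously the pattern bailed out entirely; now it collapses only dims
// 2 and 3 of the iteration space, producing a 3-D linalg.generic followed
// by a tensor.expand_shape that restores the 4-D result.
```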
Reviewed By: gysit
Differential Revision: https://reviews.llvm.org/D123153
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 45d57c378e4d0..b97e654bae4b5 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1151,43 +1151,24 @@ struct FoldReshapeWithGenericOpByExpansion
// contraction of dimensions.
//===---------------------------------------------------------------------===//
-/// For an `indexingMap` that is a projected permutation, if the range is to be
-/// collapsed using the given `reassociation`, get the reassociation in the
-/// domain that would keep the map a projected permutation.
-static SmallVector<ReassociationIndices>
+/// For a given list of indices in the range of the `indexingMap` that are
+/// folded, return the indices of the corresponding domain. Ensures that all
+/// the elements of the returned reassociation are distinct.
+static ReassociationIndices
getDomainReassociation(AffineMap indexingMap,
- ArrayRef<ReassociationIndices> rangeReassociation) {
+ ReassociationIndicesRef rangeReassociation) {
assert(indexingMap.isProjectedPermutation() &&
- "expected projected permutation map");
- unsigned counter = 0;
- SmallVector<ReassociationIndices> domainReassociation;
- llvm::SmallDenseSet<unsigned, 4> processedDomainDims;
- // Iterate over the reassociation indices.
- for (ReassociationIndicesRef foldedRangeDims : rangeReassociation) {
- ReassociationIndices foldedDomainDims;
- for (auto rangeDim : foldedRangeDims) {
- (void)rangeDim;
- AffineDimExpr dimExpr =
- indexingMap.getResult(counter++).cast<AffineDimExpr>();
- foldedDomainDims.push_back(dimExpr.getPosition());
- processedDomainDims.insert(dimExpr.getPosition());
- }
- domainReassociation.emplace_back(std::move(foldedDomainDims));
- }
- // Fill in the missing domain dims.
- for (auto dim : llvm::seq<unsigned>(0, indexingMap.getNumDims())) {
- if (processedDomainDims.count(dim))
- continue;
- ReassociationIndices vec = {dim};
- domainReassociation.emplace_back(std::move(vec));
- }
+ "expected projected permutation");
- // Sort the reassociation using the first dimension of the folded range to
- // not create unnecessary transposes.
- llvm::sort(domainReassociation,
- [](ReassociationIndicesRef lhs, ReassociationIndicesRef rhs) {
- return lhs[0] < rhs[0];
- });
+ ReassociationIndices domainReassociation = llvm::to_vector<4>(
+ llvm::map_range(rangeReassociation, [&](int64_t pos) -> int64_t {
+ return indexingMap.getResults()[pos]
+ .cast<AffineDimExpr>()
+ .getPosition();
+ }));
+  // The projected permutation semantics ensure that there is no repetition
+  // of the domain indices.
return domainReassociation;
}
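As a worked example of the refactored utility (a hypothetical projected
permutation, not taken from the patch):

```mlir
#map = affine_map<(d0, d1, d2, d3) -> (d1, d0, d2, d3)>
// For the range group [0, 1], results 0 and 1 of #map are d1 and d0, so the
// domain reassociation is [1, 0]; for the range group [2, 3] it is [2, 3].
```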
@@ -1235,100 +1216,238 @@ static bool isDimSequencePreserved(AffineMap indexingMap,
return true;
}
-// Check if a generic op can be fused along an operand by collapsing dimensions.
-static bool isFusableWithReshapeByDimCollapse(
- GenericOp genericOp, OpOperand *fusableOperand,
- ArrayRef<ReassociationIndices> reassociation) {
+// Return the list of dimensions of the iteration domain that can be
+// collapsed to allow for fusion with a producer that is an expand_shape
+// operation. If all dimensions created by the expansion can be collapsed in
+// the iteration space, then the reshape becomes redundant.
+//
+// Example:
+//
+// ```mlir
+// #map = affine_map<(d0, d1) -> (d0, d1)>
+// %1 = tensor.expand_shape %0 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
+// %2 = linalg.init_tensor [..] : tensor<?x4xf32>
+// %3 = linalg.generic {
+// indexing_maps = [#map, #map],
+// iterator_types = ["parallel" ,"parallel"]}
+// ins(%1 : tensor<?x4xf32>) outs(%2 : tensor<?x4xf32>) {.. }
+// ```
+//
+// can be fused by collapsing the dimensions of the iteration space.
+//
+// ```mlir
+// #map = affine_map<(d0) -> (d0)>
+// %2 = linalg.init_tensor [..] : tensor<?xf32>
+// %3 = linalg.generic {
+// indexing_maps = [#map, #map],
+// iterator_types = ["parallel"]}
+// ins(%1 : tensor<?xf32>) outs(%2 : tensor<?xf32>) {.. }
+// %4 = tensor.expand_shape %3 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
+// ```
+//
+// In the following example,
+//
+// ```mlir
+// #map0 = affine_map<(d0, d1) -> (d0, d1)>
+// #map1 = affine_map<(d0, d1) -> (d1, d0)>
+// %1 = tensor.expand_shape %0 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
+// %2 = linalg.init_tensor [..] : tensor<4x?xf32>
+// %2 = linalg.generic {
+// indexing_maps = [#map0, #map1],
+// iterator_types = ["parallel" ,"parallel"]}
+// ins(%1 : tensor<?x4xf32>) outs(%2 : tensor<4x?xf32>) {.. }
+// ```
+//
+// the reshape cannot be fused with the generic op by collapsing the op
+// dimensions, since the indexing maps would have to contain mods and divs
+// to preserve the access pattern. When no dimensions of the iteration
+// space are collapsible, an empty vector is returned.
+static SmallVector<ReassociationIndices>
+getCollapsableIterationSpaceDims(GenericOp genericOp, OpOperand *fusableOperand,
+ ArrayRef<ReassociationIndices> reassociation) {
// Some basic checks for this fusion to be valid.
if (!genericOp.hasTensorSemantics() || genericOp.getNumOutputs() != 1)
- return false;
+ return {};
if (!llvm::all_of(genericOp.getIndexingMaps(), [](AffineMap map) {
return map.isProjectedPermutation();
})) {
- return false;
+ return {};
}
- // Get the reassociation for the iteration space.
- SmallVector<ReassociationIndices> iterationReassociation =
- getDomainReassociation(genericOp.getTiedIndexingMap(fusableOperand),
- reassociation);
- if (iterationReassociation.empty()) {
- // If the domain reassociation indices is empty, then this is a scalar op.
- // Nothing to do.
- return false;
+ // Compute all the loops with the reduction iterator types.
+ SmallVector<int64_t> reductionDims;
+ for (auto iteratorType : llvm::enumerate(genericOp.iterator_types())) {
+ if (isReductionIterator(iteratorType.value())) {
+ reductionDims.push_back(iteratorType.index());
+ }
}
+ llvm::SmallDenseSet<unsigned, 4> processedIterationDims;
+ AffineMap indexingMap = genericOp.getTiedIndexingMap(fusableOperand);
auto iteratorTypes = genericOp.iterator_types().getValue();
- ArrayRef<Attribute> iteratorTypesRef(iteratorTypes);
- for (ReassociationIndicesRef foldedIterDims : iterationReassociation) {
- // Check that for all indexing maps, the folded dimensions sequence is
- // preserved.
- if (!llvm::all_of(genericOp.getIndexingMaps(), [&](AffineMap indexingMap) {
- return isDimSequencePreserved(indexingMap, foldedIterDims);
+ SmallVector<ReassociationIndices> iterationSpaceReassociation;
+ for (ReassociationIndicesRef foldedRangeDims : reassociation) {
+ assert(!foldedRangeDims.empty() && "unexpected empty reassociation");
+
+ // Ignore dims that are not folded.
+ if (foldedRangeDims.size() == 1)
+ continue;
+
+ ReassociationIndices foldedIterationSpaceDims =
+ getDomainReassociation(indexingMap, foldedRangeDims);
+
+ // Check that the folded iteration dims do not contain already processed
+ // dims.
+ if (llvm::any_of(foldedIterationSpaceDims, [&](int64_t dim) {
+ return processedIterationDims.count(dim);
}))
- return false;
- unsigned startDim = foldedIterDims[0];
- ArrayRef<Attribute> foldedIteratorTypes =
- iteratorTypesRef.drop_front(startDim).take_front(foldedIterDims.size());
- // Check that all folded iterator types are either all parallel, or all
- // reduction.
- if (!llvm::all_of(
- foldedIteratorTypes,
- [](Attribute attr) { return isParallelIterator(attr); }) &&
- !llvm::all_of(foldedIteratorTypes,
- [](Attribute attr) { return isReductionIterator(attr); }))
- return false;
+ continue;
+
+    // Check that the folded iterator types are either all parallel or all reduction.
+ Attribute startIteratorType = iteratorTypes[foldedIterationSpaceDims[0]];
+ if (!isParallelIterator(startIteratorType) &&
+ !isReductionIterator(startIteratorType))
+ continue;
+ if (llvm::any_of(foldedIterationSpaceDims, [&](int64_t dim) {
+ return iteratorTypes[dim] != startIteratorType;
+ }))
+ continue;
+
+    // If the folded dimensions correspond to a "reduction" iterator type,
+    // the folded dimensions need to be "in-order". Strictly speaking this is
+    // not necessary for reductions that are associative and commutative, but
+    // a stricter definition of reduction is used for now.
+ if (isReductionIterator(startIteratorType)) {
+ bool isContiguous = false;
+ for (auto startDim : llvm::enumerate(reductionDims)) {
+ // Move window in `reductionDims` to start of the folded iteration dims.
+ if (startDim.value() != foldedIterationSpaceDims[0])
+ continue;
+      // If the sizes don't match, it is trivially not contiguous. This
+      // condition should not be hit.
+ if (startDim.index() + foldedIterationSpaceDims.size() >
+ reductionDims.size())
+ break;
+ // Check that the contiguity is maintained.
+ isContiguous = true;
+ for (auto foldedDim : llvm::enumerate(foldedIterationSpaceDims)) {
+ if (reductionDims[foldedDim.index() + startDim.index()] !=
+ foldedDim.value()) {
+ isContiguous = false;
+ break;
+ }
+ }
+ break;
+ }
+ if (!isContiguous)
+ continue;
+ }
+
+ // Check that the sequence is preserved in all indexing maps.
+ if (llvm::any_of(genericOp.getIndexingMaps(), [&](AffineMap indexingMap) {
+ return !isDimSequencePreserved(indexingMap, foldedIterationSpaceDims);
+ }))
+ continue;
+
+ processedIterationDims.insert(foldedIterationSpaceDims.begin(),
+ foldedIterationSpaceDims.end());
+ iterationSpaceReassociation.emplace_back(
+ std::move(foldedIterationSpaceDims));
}
- return true;
+
+ return iterationSpaceReassociation;
}
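To make the "in-order" reduction restriction concrete, a hedged sketch in the
spirit of the `no_fold_non_consecutive_reduction_dims` test added below:

```mlir
// iterator_types = ["reduction", "reduction", "reduction", "reduction"]
// The reduction loops execute in the order [0, 1, 2, 3]. Folding [1, 2] is
// a contiguous window of that order and passes the check; folding [0, 2]
// skips reduction dim 1 and is rejected, since collapsing it would reorder
// the reductions.
```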
/// Helper class to carry state while collapsing the `linalg.generic` op.
namespace {
class CollapsingInfo {
public:
- CollapsingInfo(SmallVector<ReassociationIndices> &&reassociation) {
- iterationReassociation = std::move(reassociation);
- for (const auto &foldedIterDims : enumerate(iterationReassociation)) {
- foldedDimStartToSequenceMap[foldedIterDims.value()[0]] =
- foldedIterDims.index();
+ LogicalResult initialize(unsigned origNumLoops,
+ ArrayRef<ReassociationIndices> foldedIterationDims) {
+ llvm::SmallDenseSet<int64_t, 4> processedDims;
+ // Find all the dims that are folded.
+ for (ReassociationIndicesRef foldedIterationDim : foldedIterationDims) {
+ if (foldedIterationDim.empty())
+ continue;
+      // If the folded dims contain dims that are already folded, the
+      // specification is illegal. Repetition within a list is also illegal.
+ for (auto dim : foldedIterationDim) {
+ if (dim >= origNumLoops)
+ return failure();
+ if (processedDims.count(dim))
+ return failure();
+ processedDims.insert(dim);
+ }
+ collapsedOpToOrigOpIterationDim.emplace_back(foldedIterationDim.begin(),
+ foldedIterationDim.end());
+ }
+ if (processedDims.size() > origNumLoops)
+ return failure();
+
+ // Add all the preserved dims of the original op as single
+ // elements to `collapsedOpToOrigOpIterationDim`.
+ for (auto dim : llvm::seq<int64_t>(0, origNumLoops)) {
+ if (processedDims.count(dim))
+ continue;
+ collapsedOpToOrigOpIterationDim.emplace_back(ReassociationIndices{dim});
}
- }
- // Returns the iteration space reassociation.
- ArrayRef<ReassociationIndices> getReassociationIndices() {
- return iterationReassociation;
+ llvm::sort(collapsedOpToOrigOpIterationDim,
+ [&](ReassociationIndicesRef lhs, ReassociationIndicesRef rhs) {
+ return lhs[0] < rhs[0];
+ });
+ origOpToCollapsedOpIterationDim.resize(origNumLoops);
+ for (auto foldedDims : llvm::enumerate(collapsedOpToOrigOpIterationDim)) {
+ for (auto dim : enumerate(foldedDims.value()))
+ origOpToCollapsedOpIterationDim[dim.value()] =
+ std::make_pair<int64_t, unsigned>(foldedDims.index(), dim.index());
+ }
+ return success();
}
- // Returns true if the given dimension is the start of a sequence of folded
- // dimensions.
- bool isDimStartOfFoldedDims(unsigned dim) {
- return foldedDimStartToSequenceMap.count(dim);
+ /// Return mapping from collapsed loop domain to original loop domain.
+ ArrayRef<ReassociationIndices> getCollapsedOpToOrigOpMapping() const {
+ return collapsedOpToOrigOpIterationDim;
}
- // Return the folded dimensions starting at `dim`.
- ReassociationIndicesRef getFoldedDimsStartingAt(unsigned dim) {
- assert(foldedDimStartToSequenceMap.count(dim) &&
- "invalid start dim of folded dim "
- "sequence");
- return iterationReassociation[foldedDimStartToSequenceMap[dim]];
+  /// Return mapping from original loop domain to collapsed loop domain. The
+  /// mapping is a pair. The first value is the dimension in the collapsed loop
+  /// that the original loop is mapped to. The second is the relative position
+  /// in the folded list of this dimension. For example, if the original loop
+  /// domain is 5D and is folded into two dimensions, i.e.
+  ///
+  /// ```
+  /// collapsedOpToOrigOpMapping = [[0, 1, 2], [3, 4]]
+ /// ```
+ ///
+ /// then
+ ///
+ /// ```
+ /// origOpToCollapsedOpMapping[0] = {0, 0};
+ /// origOpToCollapsedOpMapping[1] = {0, 1};
+ /// origOpToCollapsedOpMapping[2] = {0, 2};
+ /// origOpToCollapsedOpMapping[3] = {1, 0};
+ /// origOpToCollapsedOpMapping[4] = {1, 1};
+ /// ```
+ ///
+ ArrayRef<std::pair<int64_t, unsigned>> getOrigOpToCollapsedOpMapping() const {
+ return origOpToCollapsedOpIterationDim;
}
- // For a dim in the original op, return the dim in the collapsed op, that it
- // is mapped to. Expectes `dim` to be start of a folded dimension sequence.
- unsigned getDimInCollapsedOpForStartOfFoldedDims(unsigned dim) {
- assert(foldedDimStartToSequenceMap.count(dim) &&
- "invalid start dim of folded dim sequence");
- return foldedDimStartToSequenceMap[dim];
+ /// Return the collapsed op iteration domain rank.
+ unsigned getCollapsedOpIterationRank() const {
+ return collapsedOpToOrigOpIterationDim.size();
}
private:
- /// Reassociation describing the folded iteration space dimensions.
- SmallVector<ReassociationIndices> iterationReassociation;
+ /// Map from the iteration domain index in collapsed op to the iteration
+ /// domain indices in the original op.
+ SmallVector<ReassociationIndices> collapsedOpToOrigOpIterationDim;
- /// Map from the starting dimensions of the folded dimension sequences to
- /// their index in `iterationReassociation`.
- llvm::DenseMap<unsigned, unsigned> foldedDimStartToSequenceMap;
+ /// Map from iteration domain index in the original op to the iteration domain
+ /// index in the collapsed op.
+ SmallVector<std::pair<int64_t, unsigned>> origOpToCollapsedOpIterationDim;
};
} // namespace
@@ -1336,10 +1455,10 @@ class CollapsingInfo {
/// iterator types and collapsed dimensions.
static SmallVector<StringRef>
getCollapsedOpIteratorTypes(ArrayRef<Attribute> iteratorTypes,
- CollapsingInfo &collapsingInfo) {
+ const CollapsingInfo &collapsingInfo) {
SmallVector<StringRef> collapsedIteratorTypes;
for (ReassociationIndicesRef foldedIterDims :
- collapsingInfo.getReassociationIndices()) {
+ collapsingInfo.getCollapsedOpToOrigOpMapping()) {
assert(!foldedIterDims.empty() &&
"reassociation indices expected to have non-empty sets");
// Just pick the iterator type of the first folded dim. Pre-condition checks
@@ -1353,35 +1472,50 @@ getCollapsedOpIteratorTypes(ArrayRef<Attribute> iteratorTypes,
/// Compute the indexing map in the collapsed op that corresponds to the given
/// `indexingMap` of the original operation.
-static AffineMap getCollapsedOpIndexingMap(AffineMap indexingMap,
- CollapsingInfo &collapsingInfo) {
+static AffineMap
+getCollapsedOpIndexingMap(AffineMap indexingMap,
+ const CollapsingInfo &collapsingInfo) {
MLIRContext *context = indexingMap.getContext();
assert(indexingMap.isProjectedPermutation() &&
"expected indexing map to be projected permutation");
SmallVector<AffineExpr> resultExprs;
+ auto origOpToCollapsedOpMapping =
+ collapsingInfo.getOrigOpToCollapsedOpMapping();
for (auto expr : indexingMap.getResults()) {
unsigned dim = expr.cast<AffineDimExpr>().getPosition();
- if (collapsingInfo.isDimStartOfFoldedDims(dim)) {
- resultExprs.push_back(getAffineDimExpr(
- collapsingInfo.getDimInCollapsedOpForStartOfFoldedDims(dim),
- context));
- }
+    // If the dim is not the first of its group of collapsed dims, do nothing.
+ if (origOpToCollapsedOpMapping[dim].second != 0)
+ continue;
+ // The next n-dims are guaranteed to be collapsed. So just use the
+ // iteration dimension of the collapsed op.
+ resultExprs.push_back(
+ getAffineDimExpr(origOpToCollapsedOpMapping[dim].first, context));
}
- return AffineMap::get(collapsingInfo.getReassociationIndices().size(), 0,
+ return AffineMap::get(collapsingInfo.getCollapsedOpIterationRank(), 0,
resultExprs, context);
}
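For instance, with folded iteration dims `[[2, 3]]` (hypothetical, but
consistent with the `fuse_only_one_reassociation` test below), each folded
group contributes a single result:

```mlir
// Original: each of d2 and d3 appears as its own result.
#map_orig      = affine_map<(d0, d1, d2, d3) -> (d1, d0, d2, d3)>
// Collapsed: d2 now stands for the folded (d2, d3) pair.
#map_collapsed = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
```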
/// Return the `reassociation` indices to use to collapse the operand when the
/// iteration space of a generic op is collapsed.
static SmallVector<ReassociationIndices>
-getOperandReassociation(AffineMap indexingMap, CollapsingInfo &collapsingInfo) {
+getOperandReassociation(AffineMap indexingMap,
+ const CollapsingInfo &collapsingInfo) {
unsigned counter = 0;
SmallVector<ReassociationIndices> operandReassociation;
- for (auto expr : indexingMap.getResults()) {
- unsigned dim = expr.cast<AffineDimExpr>().getPosition();
- if (collapsingInfo.isDimStartOfFoldedDims(dim)) {
+ auto origOpToCollapsedOpMapping =
+ collapsingInfo.getOrigOpToCollapsedOpMapping();
+ auto collapsedOpToOrigOpMapping =
+ collapsingInfo.getCollapsedOpToOrigOpMapping();
+ while (counter < indexingMap.getNumResults()) {
+ unsigned dim =
+ indexingMap.getResult(counter).cast<AffineDimExpr>().getPosition();
+ if (origOpToCollapsedOpMapping[dim].second == 0) {
+      // This is the start of a group of collapsed iteration dimensions that
+      // is guaranteed to be preserved in the indexing map. The number of
+      // folded dims is obtained from the collapsed op to original op mapping.
unsigned numFoldedDims =
- collapsingInfo.getFoldedDimsStartingAt(dim).size();
+ collapsedOpToOrigOpMapping[origOpToCollapsedOpMapping[dim].first]
+ .size();
auto range = llvm::seq<unsigned>(counter, counter + numFoldedDims);
operandReassociation.emplace_back(range.begin(), range.end());
counter += numFoldedDims;
@@ -1393,7 +1527,7 @@ getOperandReassociation(AffineMap indexingMap, CollapsingInfo &collapsingInfo) {
/// Get the new value to use for a given `OpOperand` in the collapsed operation.
static Value getCollapsedOpOperand(Location loc, GenericOp genericOp,
OpOperand *opOperand,
- CollapsingInfo &collapsingInfo,
+ const CollapsingInfo &collapsingInfo,
OpBuilder &builder) {
AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
SmallVector<ReassociationIndices> operandReassociation =
@@ -1414,7 +1548,7 @@ static Value getCollapsedOpOperand(Location loc, GenericOp genericOp,
/// Modify the `linalg.index` operations in the original generic op, to its
/// value in the collapsed operation.
void generateCollapsedIndexingRegion(Location loc, Block *block,
- CollapsingInfo &collapsingInfo,
+ const CollapsingInfo &collapsingInfo,
ValueRange loopRange,
PatternRewriter &rewriter) {
OpBuilder::InsertionGuard g(rewriter);
@@ -1431,7 +1565,8 @@ void generateCollapsedIndexingRegion(Location loc, Block *block,
// i1 = (i_{folded} / d2) % d1
// i0 = i_{folded} / (d1 * d2)
llvm::DenseMap<unsigned, Value> indexReplacementVals;
- for (auto &foldedDims : enumerate(collapsingInfo.getReassociationIndices())) {
+ for (auto &foldedDims :
+ enumerate(collapsingInfo.getCollapsedOpToOrigOpMapping())) {
ReassociationIndicesRef foldedDimsRef(foldedDims.value());
Value newIndexVal =
rewriter.create<linalg::IndexOp>(loc, foldedDims.index());
@@ -1451,24 +1586,22 @@ void generateCollapsedIndexingRegion(Location loc, Block *block,
}
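A sketch of the rewritten region body for a single folded pair (hypothetical
extents; the real inner size is taken from `loopRange`): collapsing `(d0, d1)`
with inner extent `d1 = 4` replaces the two `linalg.index` results with a
div/mod of the folded index, as in the `fold_non_consecutive_dims` test below:

```mlir
%c4 = arith.constant 4 : index
%folded = linalg.index 0 : index
// i1 = i_folded % d1 and i0 = i_folded / d1.
%i1 = arith.remui %folded, %c4 : index
%i0 = arith.divui %folded, %c4 : index
```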
/// Implementation of fusion with reshape operation by collapsing dimensions.
-static Optional<SmallVector<Value>>
-fuseWithReshapeByCollapsing(GenericOp genericOp, Operation *reshapeOp,
- OpOperand *fusableOpOperand,
- PatternRewriter &rewriter) {
- SmallVector<ReassociationIndices> reassociation =
- isa<tensor::CollapseShapeOp>(reshapeOp)
- ? cast<tensor::CollapseShapeOp>(reshapeOp).getReassociationIndices()
- : cast<tensor::ExpandShapeOp>(reshapeOp).getReassociationIndices();
- assert(isFusableWithReshapeByDimCollapse(genericOp, fusableOpOperand,
- reassociation) &&
- "preconditions for fusing with reshape by collapse failed");
-
- CollapsingInfo collapsingInfo(getDomainReassociation(
- genericOp.getTiedIndexingMap(fusableOpOperand), reassociation));
- // Check for trivial no transformation cases. In that case return nothing.
- if (collapsingInfo.getReassociationIndices().size() ==
- genericOp.getNumLoops())
- return llvm::None;
+static FailureOr<SmallVector<Value>> collapseGenericOpIterationDims(
+ GenericOp genericOp, ArrayRef<ReassociationIndices> foldedIterationDims,
+ OpOperand *fusableOpOperand, PatternRewriter &rewriter) {
+ // Bail on trivial no-op cases.
+ if (genericOp.getNumLoops() <= 1 || foldedIterationDims.empty() ||
+ llvm::all_of(foldedIterationDims, [](ReassociationIndicesRef foldedDims) {
+ return foldedDims.size() <= 1;
+ }))
+ return failure();
+
+ CollapsingInfo collapsingInfo;
+ if (failed(collapsingInfo.initialize(genericOp.getNumLoops(),
+ foldedIterationDims))) {
+ return rewriter.notifyMatchFailure(
+ genericOp, "illegal to collapse specified dimensions");
+ }
// Get the iterator types for the operand.
SmallVector<StringRef> iteratorTypes = getCollapsedOpIteratorTypes(
@@ -1480,8 +1613,7 @@ fuseWithReshapeByCollapsing(GenericOp genericOp, Operation *reshapeOp,
return getCollapsedOpIndexingMap(map, collapsingInfo);
}));
- Location loc =
- rewriter.getFusedLoc({genericOp->getLoc(), reshapeOp->getLoc()});
+ Location loc = genericOp->getLoc();
// Get the input operands.
auto inputOperands = llvm::to_vector(
@@ -1576,14 +1708,17 @@ class FoldWithProducerReshapeOpByCollapsing
if (!reshapeOp)
continue;
- if (!isFusableWithReshapeByDimCollapse(
- genericOp, opOperand, reshapeOp.getReassociationIndices()) ||
+ SmallVector<ReassociationIndices> collapsableIterationDims =
+ getCollapsableIterationSpaceDims(genericOp, opOperand,
+ reshapeOp.getReassociationIndices());
+ if (collapsableIterationDims.empty() ||
!controlFoldingReshapes(reshapeOp->getResult(0), *opOperand)) {
continue;
}
- Optional<SmallVector<Value>> replacements = fuseWithReshapeByCollapsing(
- genericOp, reshapeOp, opOperand, rewriter);
+ Optional<SmallVector<Value>> replacements =
+ collapseGenericOpIterationDims(genericOp, collapsableIterationDims,
+ opOperand, rewriter);
if (!replacements) {
return rewriter.notifyMatchFailure(
genericOp, "failed to do the fusion by collapsing transformation");
diff --git a/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir b/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
index c45c5b296a781..ee49c929af0e2 100644
--- a/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
+++ b/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
@@ -399,3 +399,128 @@ func @zero_D_test(%arg0: tensor<f32>) -> tensor<1xf32> {
// CHECK: %[[GENERIC:.+]] = linalg.generic
// CHECK-SAME: ins(%[[EXPAND]] :
// CHECK: return %[[GENERIC]]
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2, d3) -> (d1, d0, d2, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+func @fuse_only_one_reassociation(%arg0 : tensor<?x?xf32>, %arg1 : tensor<4x?x?x8xf32>) -> tensor<4x?x?x8xf32> {
+ %0 = tensor.expand_shape %arg0 [[0, 1], [2, 3]] : tensor<?x?xf32> into tensor<?x4x?x8xf32>
+ %1 = linalg.generic {
+ indexing_maps = [#map0, #map1, #map1],
+ iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+ ins(%0, %arg1 : tensor<?x4x?x8xf32>, tensor<4x?x?x8xf32>)
+ outs(%arg1 : tensor<4x?x?x8xf32>) {
+ ^bb0(%b0: f32, %b1 : f32, %b2 : f32):
+ %2 = arith.addf %b0, %b1 : f32
+ linalg.yield %2 : f32
+ } -> tensor<4x?x?x8xf32>
+ return %1 : tensor<4x?x?x8xf32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK: func @fuse_only_one_reassociation(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG1:.+]]: tensor<4x?x?x8xf32>
+// CHECK-DAG: %[[EXPAND_ARG0:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2, 3]{{\]}}
+// CHECK-DAG: %[[COLLAPSE_ARG0:.+]] = tensor.collapse_shape %[[EXPAND_ARG0]] {{\[}}[0], [1], [2, 3]{{\]}}
+// CHECK-DAG: %[[COLLAPSE_ARG1_0:.+]] = tensor.collapse_shape %[[ARG1]] {{\[}}[0], [1], [2, 3]{{\]}}
+// CHECK-DAG: %[[COLLAPSE_ARG1_1:.+]] = tensor.collapse_shape %[[ARG1]] {{\[}}[0], [1], [2, 3]{{\]}}
+// CHECK: %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP1]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
+// CHECK-SAME: ins(%[[COLLAPSE_ARG0]], %[[COLLAPSE_ARG1_0]] :
+// CHECK-SAME: outs(%[[COLLAPSE_ARG1_1]] :
+// CHECK: %[[EXPAND_GENERIC:.+]] = tensor.expand_shape %[[GENERIC]] {{\[}}[0], [1], [2, 3]{{\]}}
+// CHECK: return %[[EXPAND_GENERIC]]
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3, d1)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d3, d1, d0, d2)>
+func @fold_non_consecutive_dims(%arg0 : tensor<?x?xi32>) -> tensor<?x8x?x4xi32> {
+ %c0 = arith.constant 0 : index
+ %c2 = arith.constant 2 : index
+ %0 = tensor.expand_shape %arg0 [[0, 1], [2, 3]] : tensor<?x?xi32> into tensor<?x4x?x8xi32>
+ %d0 = tensor.dim %0, %c0 : tensor<?x4x?x8xi32>
+ %d1 = tensor.dim %0, %c2 : tensor<?x4x?x8xi32>
+ %init = linalg.init_tensor [%d1, 8, %d0, 4] : tensor<?x8x?x4xi32>
+ %1 = linalg.generic {
+ indexing_maps = [#map0, #map1],
+ iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+ ins(%0 : tensor<?x4x?x8xi32>) outs(%init : tensor<?x8x?x4xi32>) {
+ ^bb0(%b0 : i32, %b1 : i32):
+ %2 = linalg.index 0 : index
+ %3 = linalg.index 1 : index
+ %4 = linalg.index 2 : index
+ %5 = linalg.index 3 : index
+ %6 = arith.addi %2, %3 : index
+ %7 = arith.addi %6, %4 : index
+ %8 = arith.addi %7, %5 : index
+ %9 = arith.index_cast %8 : index to i32
+ linalg.yield %9: i32
+ } -> tensor<?x8x?x4xi32>
+ return %1 : tensor<?x8x?x4xi32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d1, d0)>
+// CHECK: func @fold_non_consecutive_dims(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xi32>)
+// CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index
+// CHECK-DAG: %[[C8:.+]] = arith.constant 8 : index
+// CHECK: %[[INIT:.+]] = linalg.init_tensor
+// CHECK: %[[COLLAPSE_INIT:.+]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1], [2, 3]{{\]}}
+// CHECK: %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel"]
+// CHECK-SAME: ins(%[[ARG0]] :
+// CHECK-SAME: outs(%[[COLLAPSE_INIT]] :
+// CHECK-NEXT: ^bb{{[0-9]}}
+// CHECK: %[[ID0:.+]] = linalg.index 0
+// CHECK-DAG: %[[T0:.+]] = arith.remui %[[ID0]], %[[C4]]
+// CHECK-DAG: %[[T1:.+]] = arith.divui %[[ID0]], %[[C4]]
+// CHECK: %[[ID1:.+]] = linalg.index 1
+// CHECK-DAG: %[[T2:.+]] = arith.remui %[[ID1]], %[[C8]]
+// CHECK-DAG: %[[T3:.+]] = arith.divui %[[ID1]], %[[C8]]
+// CHECK-DAG: %[[T4:.+]] = arith.addi %[[T1]], %[[T2]]
+// CHECK-DAG: %[[T5:.+]] = arith.addi %[[T4]], %[[T0]]
+// CHECK-DAG: %[[T6:.+]] = arith.addi %[[T5]], %[[T3]]
+// CHECK-DAG: %[[T7:.+]] = arith.index_cast %[[T6]]
+// CHECK: linalg.yield %[[T7]]
+// CHECK: %[[EXPAND_GENERIC:.+]] = tensor.expand_shape %[[GENERIC]] {{\[}}[0, 1], [2, 3]{{\]}}
+// CHECK: return %[[EXPAND_GENERIC]]
+
+// -----
+
+// None of the folded iteration space dims are contiguous reduction
+// dimensions, so the op is left unchanged.
+#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3, d1)>
+#map1 = affine_map<(d0, d1, d2, d3) -> ()>
+func @no_fold_non_consecutive_reduction_dims(%arg0 : tensor<?x?xi32>) -> tensor<i32> {
+ %c0 = arith.constant 0 : index
+ %c2 = arith.constant 2 : index
+ %0 = tensor.expand_shape %arg0 [[0, 1], [2, 3]] : tensor<?x?xi32> into tensor<?x4x?x8xi32>
+ %init = linalg.init_tensor [] : tensor<i32>
+ %1 = linalg.generic {
+ indexing_maps = [#map0, #map1],
+ iterator_types = ["reduction", "reduction", "reduction", "reduction"]}
+ ins(%0 : tensor<?x4x?x8xi32>) outs(%init : tensor<i32>) {
+ ^bb0(%b0 : i32, %b1 : i32):
+ %2 = linalg.index 0 : index
+ %3 = linalg.index 1 : index
+ %4 = linalg.index 2 : index
+ %5 = linalg.index 3 : index
+ %6 = arith.addi %2, %3 : index
+ %7 = arith.addi %6, %4 : index
+ %8 = arith.addi %7, %5 : index
+ %9 = arith.index_cast %8 : index to i32
+ linalg.yield %9: i32
+ } -> tensor<i32>
+ return %1 : tensor<i32>
+}
+// CHECK: func @no_fold_non_consecutive_reduction_dims(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xi32>)
+// CHECK: %[[EXPAND_ARG0:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2, 3]{{\]}}
+// CHECK: %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[EXPAND_ARG0]] :
+// CHECK: return %[[GENERIC]]