[Mlir-commits] [mlir] b40e901 - [mlir][Linalg] Allow collapsing subset of the reassociations when fusing by collapsing.

Mahesh Ravishankar llvmlistbot at llvm.org
Tue Apr 12 11:56:42 PDT 2022


Author: Mahesh Ravishankar
Date: 2022-04-12T18:56:32Z
New Revision: b40e901333b903fd71f17c3314d3e40f8abde074

URL: https://github.com/llvm/llvm-project/commit/b40e901333b903fd71f17c3314d3e40f8abde074
DIFF: https://github.com/llvm/llvm-project/commit/b40e901333b903fd71f17c3314d3e40f8abde074.diff

LOG: [mlir][Linalg] Allow collapsing subset of the reassociations when fusing by collapsing.

This change generalizes the fusion of a `tensor.expand_shape` producer
into a `linalg.generic` consumer by collapsing dimensions, to handle
cases where only a subset of the reassociations specified in the
`tensor.expand_shape` can legally be collapsed.
The method that performs the collapsing is refactored so that it can
serve as a standalone, generic utility when required.
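
As an illustrative sketch of the new behavior (adapted from the
`@fuse_only_one_reassociation` test added in this patch; the generic
op body is elided), consider:

```mlir
#map0 = affine_map<(d0, d1, d2, d3) -> (d1, d0, d2, d3)>
#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
%0 = tensor.expand_shape %arg0 [[0, 1], [2, 3]]
    : tensor<?x?xf32> into tensor<?x4x?x8xf32>
%1 = linalg.generic {
    indexing_maps = [#map0, #map1, #map1],
    iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
    ins(%0, %arg1 : tensor<?x4x?x8xf32>, tensor<4x?x?x8xf32>)
    outs(%arg1 : tensor<4x?x?x8xf32>) { ... } -> tensor<4x?x?x8xf32>
```

The `[0, 1]` group of the reassociation maps to iteration dims
`(d1, d0)`, a sequence that `#map1` does not preserve, so collapsing it
would require divs and mods in the indexing maps. The `[2, 3]` group
maps to `(d2, d3)`, which every indexing map preserves, so only that
group is collapsed: the operands are reshaped with `tensor.collapse_shape`
`[[0], [1], [2, 3]]`, the generic op runs on a 3-D iteration space, and
a trailing `tensor.expand_shape` restores the original result type.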

Reviewed By: gysit

Differential Revision: https://reviews.llvm.org/D123153

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
    mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 45d57c378e4d0..b97e654bae4b5 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1151,43 +1151,24 @@ struct FoldReshapeWithGenericOpByExpansion
 // contraction of dimensions.
 //===---------------------------------------------------------------------===//
 
-/// For an `indexingMap` that is a projected permutation, if the range is to be
-/// collapsed using the given `reassociation`, get the reassociation in the
-/// domain that would keep the map a projected permutation.
-static SmallVector<ReassociationIndices>
+/// For a given list of indices in the range of the `indexingMap` that are
+/// folded, return the indices of the corresponding domain. The projected
+/// permutation semantics guarantee that all the elements of the returned
+/// reassociation are distinct.
+static ReassociationIndices
 getDomainReassociation(AffineMap indexingMap,
-                       ArrayRef<ReassociationIndices> rangeReassociation) {
+                       ReassociationIndicesRef rangeReassociation) {
   assert(indexingMap.isProjectedPermutation() &&
-         "expected projected permutation map");
-  unsigned counter = 0;
-  SmallVector<ReassociationIndices> domainReassociation;
-  llvm::SmallDenseSet<unsigned, 4> processedDomainDims;
-  // Iterate over the reassociation indices.
-  for (ReassociationIndicesRef foldedRangeDims : rangeReassociation) {
-    ReassociationIndices foldedDomainDims;
-    for (auto rangeDim : foldedRangeDims) {
-      (void)rangeDim;
-      AffineDimExpr dimExpr =
-          indexingMap.getResult(counter++).cast<AffineDimExpr>();
-      foldedDomainDims.push_back(dimExpr.getPosition());
-      processedDomainDims.insert(dimExpr.getPosition());
-    }
-    domainReassociation.emplace_back(std::move(foldedDomainDims));
-  }
-  // Fill in the missing domain dims.
-  for (auto dim : llvm::seq<unsigned>(0, indexingMap.getNumDims())) {
-    if (processedDomainDims.count(dim))
-      continue;
-    ReassociationIndices vec = {dim};
-    domainReassociation.emplace_back(std::move(vec));
-  }
+         "expected projected permutation");
 
-  // Sort the reassociation using the first dimension of the folded range to
-  // not create unnecessary transposes.
-  llvm::sort(domainReassociation,
-             [](ReassociationIndicesRef lhs, ReassociationIndicesRef rhs) {
-               return lhs[0] < rhs[0];
-             });
+  ReassociationIndices domainReassociation = llvm::to_vector<4>(
+      llvm::map_range(rangeReassociation, [&](int64_t pos) -> int64_t {
+        return indexingMap.getResults()[pos]
+            .cast<AffineDimExpr>()
+            .getPosition();
+      }));
+  // The projected permutation semantics ensures that there is no repetition of
+  // the domain indices.
   return domainReassociation;
 }
 
@@ -1235,100 +1216,238 @@ static bool isDimSequencePreserved(AffineMap indexingMap,
   return true;
 }
 
-// Check if a generic op can be fused along an operand by collapsing dimensions.
-static bool isFusableWithReshapeByDimCollapse(
-    GenericOp genericOp, OpOperand *fusableOperand,
-    ArrayRef<ReassociationIndices> reassociation) {
+// Return the list of dimensions of the iteration domain that can be
+// collapsed to allow for fusion with a producer that is an expand_shape
+// operation. If all the dimensions created by the expansion can be
+// collapsed in the iteration space, the reshape is folded away entirely.
+//
+// Example:
+//
+// ```mlir
+// #map = affine_map<(d0, d1) -> (d0, d1)>
+// %1 = tensor.expand_shape %0 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
+// %2 = linalg.init_tensor [..] : tensor<?x4xf32>
+// %3 = linalg.generic {
+//     indexing_maps = [#map, #map],
+//     iterator_types = ["parallel" ,"parallel"]}
+//     ins(%1 : tensor<?x4xf32>) outs(%2 : tensor<?x4xf32>) {.. }
+// ```
+//
+// can be fused by collapsing the dimensions of the iteration space.
+//
+// ```mlir
+// #map = affine_map<(d0) -> (d0)>
+// %2 = linalg.init_tensor [..] : tensor<?xf32>
+// %3 = linalg.generic {
+//     indexing_maps = [#map, #map],
+//     iterator_types = ["parallel"]}
+//     ins(%0 : tensor<?xf32>) outs(%2 : tensor<?xf32>) {.. }
+// %4 = tensor.expand_shape %3 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
+// ```
+//
+// In the following example,
+//
+// ```mlir
+// #map0 = affine_map<(d0, d1) -> (d0, d1)>
+// #map1 = affine_map<(d0, d1) -> (d1, d0)>
+// %1 = tensor.expand_shape %0 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
+// %2 = linalg.init_tensor [..] : tensor<4x?xf32>
+// %3 = linalg.generic {
+//     indexing_maps = [#map0, #map1],
+//     iterator_types = ["parallel" ,"parallel"]}
+//     ins(%1 : tensor<?x4xf32>) outs(%2 : tensor<4x?xf32>) {.. }
+// ```
+//
+// the reshape cannot be fused with the generic op by collapsing the op
+// dimensions, since the indexing maps would have to contain mods and divs
+// to preserve the access pattern. When no dimensions of the iteration
+// space are collapsible, an empty vector is returned.
+static SmallVector<ReassociationIndices>
+getCollapsableIterationSpaceDims(GenericOp genericOp, OpOperand *fusableOperand,
+                                 ArrayRef<ReassociationIndices> reassociation) {
   // Some basic checks for this fusion to be valid.
   if (!genericOp.hasTensorSemantics() || genericOp.getNumOutputs() != 1)
-    return false;
+    return {};
 
   if (!llvm::all_of(genericOp.getIndexingMaps(), [](AffineMap map) {
         return map.isProjectedPermutation();
       })) {
-    return false;
+    return {};
   }
 
-  // Get the reassociation for the iteration space.
-  SmallVector<ReassociationIndices> iterationReassociation =
-      getDomainReassociation(genericOp.getTiedIndexingMap(fusableOperand),
-                             reassociation);
-  if (iterationReassociation.empty()) {
-    // If the domain reassociation indices is empty, then this is a scalar op.
-    // Nothing to do.
-    return false;
+  // Compute all the loops with the reduction iterator types.
+  SmallVector<int64_t> reductionDims;
+  for (auto iteratorType : llvm::enumerate(genericOp.iterator_types())) {
+    if (isReductionIterator(iteratorType.value())) {
+      reductionDims.push_back(iteratorType.index());
+    }
   }
 
+  llvm::SmallDenseSet<unsigned, 4> processedIterationDims;
+  AffineMap indexingMap = genericOp.getTiedIndexingMap(fusableOperand);
   auto iteratorTypes = genericOp.iterator_types().getValue();
-  ArrayRef<Attribute> iteratorTypesRef(iteratorTypes);
-  for (ReassociationIndicesRef foldedIterDims : iterationReassociation) {
-    // Check that for all indexing maps, the folded dimensions sequence is
-    // preserved.
-    if (!llvm::all_of(genericOp.getIndexingMaps(), [&](AffineMap indexingMap) {
-          return isDimSequencePreserved(indexingMap, foldedIterDims);
+  SmallVector<ReassociationIndices> iterationSpaceReassociation;
+  for (ReassociationIndicesRef foldedRangeDims : reassociation) {
+    assert(!foldedRangeDims.empty() && "unexpected empty reassociation");
+
+    // Ignore dims that are not folded.
+    if (foldedRangeDims.size() == 1)
+      continue;
+
+    ReassociationIndices foldedIterationSpaceDims =
+        getDomainReassociation(indexingMap, foldedRangeDims);
+
+    // Check that the folded iteration dims do not contain already processed
+    // dims.
+    if (llvm::any_of(foldedIterationSpaceDims, [&](int64_t dim) {
+          return processedIterationDims.count(dim);
         }))
-      return false;
-    unsigned startDim = foldedIterDims[0];
-    ArrayRef<Attribute> foldedIteratorTypes =
-        iteratorTypesRef.drop_front(startDim).take_front(foldedIterDims.size());
-    // Check that all folded iterator types are either all parallel, or all
-    // reduction.
-    if (!llvm::all_of(
-            foldedIteratorTypes,
-            [](Attribute attr) { return isParallelIterator(attr); }) &&
-        !llvm::all_of(foldedIteratorTypes,
-                      [](Attribute attr) { return isReductionIterator(attr); }))
-      return false;
+      continue;
+
+    // Check that all folded iterator types are all parallel or all reductions.
+    Attribute startIteratorType = iteratorTypes[foldedIterationSpaceDims[0]];
+    if (!isParallelIterator(startIteratorType) &&
+        !isReductionIterator(startIteratorType))
+      continue;
+    if (llvm::any_of(foldedIterationSpaceDims, [&](int64_t dim) {
+          return iteratorTypes[dim] != startIteratorType;
+        }))
+      continue;
+
+    // If the folded dimensions correspond to a "reduction" iterator type,
+    // the folded dimensions need to be "in-order". Strictly speaking this is
+    // not necessary for reductions that are associative and commutative, but
+    // a stricter definition of reduction is used for now.
+    if (isReductionIterator(startIteratorType)) {
+      bool isContiguous = false;
+      for (auto startDim : llvm::enumerate(reductionDims)) {
+        // Move window in `reductionDims` to start of the folded iteration dims.
+        if (startDim.value() != foldedIterationSpaceDims[0])
+          continue;
+        // If the sizes don't match, the dims are trivially not contiguous.
+        // This condition should not be hit.
+        if (startDim.index() + foldedIterationSpaceDims.size() >
+            reductionDims.size())
+          break;
+        // Check that the contiguity is maintained.
+        isContiguous = true;
+        for (auto foldedDim : llvm::enumerate(foldedIterationSpaceDims)) {
+          if (reductionDims[foldedDim.index() + startDim.index()] !=
+              foldedDim.value()) {
+            isContiguous = false;
+            break;
+          }
+        }
+        break;
+      }
+      if (!isContiguous)
+        continue;
+    }
+
+    // Check that the sequence is preserved in all indexing maps.
+    if (llvm::any_of(genericOp.getIndexingMaps(), [&](AffineMap indexingMap) {
+          return !isDimSequencePreserved(indexingMap, foldedIterationSpaceDims);
+        }))
+      continue;
+
+    processedIterationDims.insert(foldedIterationSpaceDims.begin(),
+                                  foldedIterationSpaceDims.end());
+    iterationSpaceReassociation.emplace_back(
+        std::move(foldedIterationSpaceDims));
   }
-  return true;
+
+  return iterationSpaceReassociation;
 }
 
 /// Helper class to carry state while collapsing the `linalg.generic` op.
 namespace {
 class CollapsingInfo {
 public:
-  CollapsingInfo(SmallVector<ReassociationIndices> &&reassociation) {
-    iterationReassociation = std::move(reassociation);
-    for (const auto &foldedIterDims : enumerate(iterationReassociation)) {
-      foldedDimStartToSequenceMap[foldedIterDims.value()[0]] =
-          foldedIterDims.index();
+  LogicalResult initialize(unsigned origNumLoops,
+                           ArrayRef<ReassociationIndices> foldedIterationDims) {
+    llvm::SmallDenseSet<int64_t, 4> processedDims;
+    // Find all the dims that are folded.
+    for (ReassociationIndicesRef foldedIterationDim : foldedIterationDims) {
+      if (foldedIterationDim.empty())
+        continue;
+      // If the folded dims contain dims that are already folded, that is an
+      // illegal specification. Repetition within a list is also illegal.
+      for (auto dim : foldedIterationDim) {
+        if (dim >= origNumLoops)
+          return failure();
+        if (processedDims.count(dim))
+          return failure();
+        processedDims.insert(dim);
+      }
+      collapsedOpToOrigOpIterationDim.emplace_back(foldedIterationDim.begin(),
+                                                   foldedIterationDim.end());
+    }
+    if (processedDims.size() > origNumLoops)
+      return failure();
+
+    // Add all the preserved dims of the original op as single
+    // elements to `collapsedOpToOrigOpIterationDim`.
+    for (auto dim : llvm::seq<int64_t>(0, origNumLoops)) {
+      if (processedDims.count(dim))
+        continue;
+      collapsedOpToOrigOpIterationDim.emplace_back(ReassociationIndices{dim});
     }
-  }
 
-  // Returns the iteration space reassociation.
-  ArrayRef<ReassociationIndices> getReassociationIndices() {
-    return iterationReassociation;
+    llvm::sort(collapsedOpToOrigOpIterationDim,
+               [&](ReassociationIndicesRef lhs, ReassociationIndicesRef rhs) {
+                 return lhs[0] < rhs[0];
+               });
+    origOpToCollapsedOpIterationDim.resize(origNumLoops);
+    for (auto foldedDims : llvm::enumerate(collapsedOpToOrigOpIterationDim)) {
+      for (auto dim : enumerate(foldedDims.value()))
+        origOpToCollapsedOpIterationDim[dim.value()] =
+            std::make_pair<int64_t, unsigned>(foldedDims.index(), dim.index());
+    }
+    return success();
   }
 
-  // Returns true if the given dimension is the start of a sequence of folded
-  // dimensions.
-  bool isDimStartOfFoldedDims(unsigned dim) {
-    return foldedDimStartToSequenceMap.count(dim);
+  /// Return mapping from collapsed loop domain to original loop domain.
+  ArrayRef<ReassociationIndices> getCollapsedOpToOrigOpMapping() const {
+    return collapsedOpToOrigOpIterationDim;
   }
 
-  // Return the folded dimensions starting at `dim`.
-  ReassociationIndicesRef getFoldedDimsStartingAt(unsigned dim) {
-    assert(foldedDimStartToSequenceMap.count(dim) &&
-           "invalid start dim of folded dim "
-           "sequence");
-    return iterationReassociation[foldedDimStartToSequenceMap[dim]];
+  /// Return mapping from original loop domain to collapsed loop domain. The
+  /// mapping is a pair. The first value is the dimension in the collapsed loop
+  /// that the original loop is mapped to. The second is the relative position
+  /// within the folded list of this domain. For example, if the original loop
+  /// domain is 5D and is folded into two groups, i.e.
+  ///
+  /// ```
+  /// collapsedOpToOrigOpMapping = [[0, 1, 2], [3, 4]]
+  /// ```
+  ///
+  /// then
+  ///
+  /// ```
+  ///  origOpToCollapsedOpMapping[0] = {0, 0};
+  ///  origOpToCollapsedOpMapping[1] = {0, 1};
+  ///  origOpToCollapsedOpMapping[2] = {0, 2};
+  ///  origOpToCollapsedOpMapping[3] = {1, 0};
+  ///  origOpToCollapsedOpMapping[4] = {1, 1};
+  /// ```
+  ///
+  ArrayRef<std::pair<int64_t, unsigned>> getOrigOpToCollapsedOpMapping() const {
+    return origOpToCollapsedOpIterationDim;
   }
 
-  // For a dim in the original op, return the dim in the collapsed op, that it
-  // is mapped to. Expectes `dim` to be start of a folded dimension sequence.
-  unsigned getDimInCollapsedOpForStartOfFoldedDims(unsigned dim) {
-    assert(foldedDimStartToSequenceMap.count(dim) &&
-           "invalid start dim of folded dim sequence");
-    return foldedDimStartToSequenceMap[dim];
+  /// Return the collapsed op iteration domain rank.
+  unsigned getCollapsedOpIterationRank() const {
+    return collapsedOpToOrigOpIterationDim.size();
   }
 
 private:
-  /// Reassociation describing the folded iteration space dimensions.
-  SmallVector<ReassociationIndices> iterationReassociation;
+  /// Map from the iteration domain index in collapsed op to the iteration
+  /// domain indices in the original op.
+  SmallVector<ReassociationIndices> collapsedOpToOrigOpIterationDim;
 
-  /// Map from the starting dimensions of the folded dimension sequences to
-  /// their index in `iterationReassociation`.
-  llvm::DenseMap<unsigned, unsigned> foldedDimStartToSequenceMap;
+  /// Map from iteration domain index in the original op to the iteration domain
+  /// index in the collapsed op.
+  SmallVector<std::pair<int64_t, unsigned>> origOpToCollapsedOpIterationDim;
 };
 } // namespace
 
@@ -1336,10 +1455,10 @@ class CollapsingInfo {
 /// iterator types and collapsed dimensions.
 static SmallVector<StringRef>
 getCollapsedOpIteratorTypes(ArrayRef<Attribute> iteratorTypes,
-                            CollapsingInfo &collapsingInfo) {
+                            const CollapsingInfo &collapsingInfo) {
   SmallVector<StringRef> collapsedIteratorTypes;
   for (ReassociationIndicesRef foldedIterDims :
-       collapsingInfo.getReassociationIndices()) {
+       collapsingInfo.getCollapsedOpToOrigOpMapping()) {
     assert(!foldedIterDims.empty() &&
            "reassociation indices expected to have non-empty sets");
     // Just pick the iterator type of the first folded dim. Pre-condition checks
@@ -1353,35 +1472,50 @@ getCollapsedOpIteratorTypes(ArrayRef<Attribute> iteratorTypes,
 
 /// Compute the indexing map in the collapsed op that corresponds to the given
 /// `indexingMap` of the original operation.
-static AffineMap getCollapsedOpIndexingMap(AffineMap indexingMap,
-                                           CollapsingInfo &collapsingInfo) {
+static AffineMap
+getCollapsedOpIndexingMap(AffineMap indexingMap,
+                          const CollapsingInfo &collapsingInfo) {
   MLIRContext *context = indexingMap.getContext();
   assert(indexingMap.isProjectedPermutation() &&
          "expected indexing map to be projected permutation");
   SmallVector<AffineExpr> resultExprs;
+  auto origOpToCollapsedOpMapping =
+      collapsingInfo.getOrigOpToCollapsedOpMapping();
   for (auto expr : indexingMap.getResults()) {
     unsigned dim = expr.cast<AffineDimExpr>().getPosition();
-    if (collapsingInfo.isDimStartOfFoldedDims(dim)) {
-      resultExprs.push_back(getAffineDimExpr(
-          collapsingInfo.getDimInCollapsedOpForStartOfFoldedDims(dim),
-          context));
-    }
+    // If the dim is not the first of a group of collapsed dims, do nothing.
+    if (origOpToCollapsedOpMapping[dim].second != 0)
+      continue;
+    // The next n-dims are guaranteed to be collapsed. So just use the
+    // iteration dimension of the collapsed op.
+    resultExprs.push_back(
+        getAffineDimExpr(origOpToCollapsedOpMapping[dim].first, context));
   }
-  return AffineMap::get(collapsingInfo.getReassociationIndices().size(), 0,
+  return AffineMap::get(collapsingInfo.getCollapsedOpIterationRank(), 0,
                         resultExprs, context);
 }
 
 /// Return the `reassociation` indices to use to collapse the operand when the
 /// iteration space of a generic op is collapsed.
 static SmallVector<ReassociationIndices>
-getOperandReassociation(AffineMap indexingMap, CollapsingInfo &collapsingInfo) {
+getOperandReassociation(AffineMap indexingMap,
+                        const CollapsingInfo &collapsingInfo) {
   unsigned counter = 0;
   SmallVector<ReassociationIndices> operandReassociation;
-  for (auto expr : indexingMap.getResults()) {
-    unsigned dim = expr.cast<AffineDimExpr>().getPosition();
-    if (collapsingInfo.isDimStartOfFoldedDims(dim)) {
+  auto origOpToCollapsedOpMapping =
+      collapsingInfo.getOrigOpToCollapsedOpMapping();
+  auto collapsedOpToOrigOpMapping =
+      collapsingInfo.getCollapsedOpToOrigOpMapping();
+  while (counter < indexingMap.getNumResults()) {
+    unsigned dim =
+        indexingMap.getResult(counter).cast<AffineDimExpr>().getPosition();
+    if (origOpToCollapsedOpMapping[dim].second == 0) {
+      // This is the start of a group of collapsed dimensions of the iteration
+      // space that is guaranteed to be preserved in the indexing map. The
+      // number of folded dims is obtained from the collapsed-to-original map.
       unsigned numFoldedDims =
-          collapsingInfo.getFoldedDimsStartingAt(dim).size();
+          collapsedOpToOrigOpMapping[origOpToCollapsedOpMapping[dim].first]
+              .size();
       auto range = llvm::seq<unsigned>(counter, counter + numFoldedDims);
       operandReassociation.emplace_back(range.begin(), range.end());
       counter += numFoldedDims;
@@ -1393,7 +1527,7 @@ getOperandReassociation(AffineMap indexingMap, CollapsingInfo &collapsingInfo) {
 /// Get the new value to use for a given `OpOperand` in the collapsed operation.
 static Value getCollapsedOpOperand(Location loc, GenericOp genericOp,
                                    OpOperand *opOperand,
-                                   CollapsingInfo &collapsingInfo,
+                                   const CollapsingInfo &collapsingInfo,
                                    OpBuilder &builder) {
   AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
   SmallVector<ReassociationIndices> operandReassociation =
@@ -1414,7 +1548,7 @@ static Value getCollapsedOpOperand(Location loc, GenericOp genericOp,
 /// Modify the `linalg.index` operations in the original generic op, to its
 /// value in the collapsed operation.
 void generateCollapsedIndexingRegion(Location loc, Block *block,
-                                     CollapsingInfo &collapsingInfo,
+                                     const CollapsingInfo &collapsingInfo,
                                      ValueRange loopRange,
                                      PatternRewriter &rewriter) {
   OpBuilder::InsertionGuard g(rewriter);
@@ -1431,7 +1565,8 @@ void generateCollapsedIndexingRegion(Location loc, Block *block,
   //   i1 = (i_{folded} / d2) % d1
   //   i0 = i_{folded} / (d1 * d2)
   llvm::DenseMap<unsigned, Value> indexReplacementVals;
-  for (auto &foldedDims : enumerate(collapsingInfo.getReassociationIndices())) {
+  for (auto &foldedDims :
+       enumerate(collapsingInfo.getCollapsedOpToOrigOpMapping())) {
     ReassociationIndicesRef foldedDimsRef(foldedDims.value());
     Value newIndexVal =
         rewriter.create<linalg::IndexOp>(loc, foldedDims.index());
@@ -1451,24 +1586,22 @@ void generateCollapsedIndexingRegion(Location loc, Block *block,
 }
 
 /// Implementation of fusion with reshape operation by collapsing dimensions.
-static Optional<SmallVector<Value>>
-fuseWithReshapeByCollapsing(GenericOp genericOp, Operation *reshapeOp,
-                            OpOperand *fusableOpOperand,
-                            PatternRewriter &rewriter) {
-  SmallVector<ReassociationIndices> reassociation =
-      isa<tensor::CollapseShapeOp>(reshapeOp)
-          ? cast<tensor::CollapseShapeOp>(reshapeOp).getReassociationIndices()
-          : cast<tensor::ExpandShapeOp>(reshapeOp).getReassociationIndices();
-  assert(isFusableWithReshapeByDimCollapse(genericOp, fusableOpOperand,
-                                           reassociation) &&
-         "preconditions for fusing with reshape by collapse failed");
-
-  CollapsingInfo collapsingInfo(getDomainReassociation(
-      genericOp.getTiedIndexingMap(fusableOpOperand), reassociation));
-  // Check for trivial no transformation cases. In that case return nothing.
-  if (collapsingInfo.getReassociationIndices().size() ==
-      genericOp.getNumLoops())
-    return llvm::None;
+static FailureOr<SmallVector<Value>> collapseGenericOpIterationDims(
+    GenericOp genericOp, ArrayRef<ReassociationIndices> foldedIterationDims,
+    OpOperand *fusableOpOperand, PatternRewriter &rewriter) {
+  // Bail on trivial no-op cases.
+  if (genericOp.getNumLoops() <= 1 || foldedIterationDims.empty() ||
+      llvm::all_of(foldedIterationDims, [](ReassociationIndicesRef foldedDims) {
+        return foldedDims.size() <= 1;
+      }))
+    return failure();
+
+  CollapsingInfo collapsingInfo;
+  if (failed(collapsingInfo.initialize(genericOp.getNumLoops(),
+                                       foldedIterationDims))) {
+    return rewriter.notifyMatchFailure(
+        genericOp, "illegal to collapse specified dimensions");
+  }
 
   // Get the iterator types for the operand.
   SmallVector<StringRef> iteratorTypes = getCollapsedOpIteratorTypes(
@@ -1480,8 +1613,7 @@ fuseWithReshapeByCollapsing(GenericOp genericOp, Operation *reshapeOp,
         return getCollapsedOpIndexingMap(map, collapsingInfo);
       }));
 
-  Location loc =
-      rewriter.getFusedLoc({genericOp->getLoc(), reshapeOp->getLoc()});
+  Location loc = genericOp->getLoc();
 
   // Get the input operands.
   auto inputOperands = llvm::to_vector(
@@ -1576,14 +1708,17 @@ class FoldWithProducerReshapeOpByCollapsing
       if (!reshapeOp)
         continue;
 
-      if (!isFusableWithReshapeByDimCollapse(
-              genericOp, opOperand, reshapeOp.getReassociationIndices()) ||
+      SmallVector<ReassociationIndices> collapsableIterationDims =
+          getCollapsableIterationSpaceDims(genericOp, opOperand,
+                                           reshapeOp.getReassociationIndices());
+      if (collapsableIterationDims.empty() ||
           !controlFoldingReshapes(reshapeOp->getResult(0), *opOperand)) {
         continue;
       }
 
-      Optional<SmallVector<Value>> replacements = fuseWithReshapeByCollapsing(
-          genericOp, reshapeOp, opOperand, rewriter);
+      Optional<SmallVector<Value>> replacements =
+          collapseGenericOpIterationDims(genericOp, collapsableIterationDims,
+                                         opOperand, rewriter);
       if (!replacements) {
         return rewriter.notifyMatchFailure(
             genericOp, "failed to do the fusion by collapsing transformation");

diff --git a/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir b/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
index c45c5b296a781..ee49c929af0e2 100644
--- a/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
+++ b/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
@@ -399,3 +399,128 @@ func @zero_D_test(%arg0: tensor<f32>) -> tensor<1xf32> {
 //      CHECK:   %[[GENERIC:.+]] = linalg.generic
 // CHECK-SAME:       ins(%[[EXPAND]] :
 //      CHECK:   return %[[GENERIC]]
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2, d3) -> (d1, d0, d2, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+func @fuse_only_one_reassociation(%arg0 : tensor<?x?xf32>, %arg1 : tensor<4x?x?x8xf32>) -> tensor<4x?x?x8xf32> {
+  %0 = tensor.expand_shape %arg0 [[0, 1], [2, 3]] : tensor<?x?xf32> into tensor<?x4x?x8xf32>
+  %1 = linalg.generic {
+      indexing_maps = [#map0, #map1, #map1],
+      iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+      ins(%0, %arg1 : tensor<?x4x?x8xf32>, tensor<4x?x?x8xf32>)
+      outs(%arg1 : tensor<4x?x?x8xf32>) {
+    ^bb0(%b0: f32, %b1 : f32, %b2 : f32):
+      %2 = arith.addf %b0, %b1 : f32
+      linalg.yield %2 : f32
+    } -> tensor<4x?x?x8xf32>
+  return %1 : tensor<4x?x?x8xf32>
+}
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+//      CHECK: func @fuse_only_one_reassociation(
+// CHECK-SAME:     %[[ARG0:.+]]: tensor<?x?xf32>
+// CHECK-SAME:     %[[ARG1:.+]]: tensor<4x?x?x8xf32>
+//  CHECK-DAG:   %[[EXPAND_ARG0:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2, 3]{{\]}}
+//  CHECK-DAG:   %[[COLLAPSE_ARG0:.+]] = tensor.collapse_shape %[[EXPAND_ARG0]] {{\[}}[0], [1], [2, 3]{{\]}}
+//  CHECK-DAG:   %[[COLLAPSE_ARG1_0:.+]] = tensor.collapse_shape %[[ARG1]] {{\[}}[0], [1], [2, 3]{{\]}}
+//  CHECK-DAG:   %[[COLLAPSE_ARG1_1:.+]] = tensor.collapse_shape %[[ARG1]] {{\[}}[0], [1], [2, 3]{{\]}}
+//      CHECK:   %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME:       indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP1]]]
+// CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel"]
+// CHECK-SAME:       ins(%[[COLLAPSE_ARG0]], %[[COLLAPSE_ARG1_0]] :
+// CHECK-SAME:       outs(%[[COLLAPSE_ARG1_1]] :
+//      CHECK:   %[[EXPAND_GENERIC:.+]] = tensor.expand_shape %[[GENERIC]] {{\[}}[0], [1], [2, 3]{{\]}}
+//      CHECK:   return %[[EXPAND_GENERIC]]
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3, d1)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d3, d1, d0, d2)>
+func @fold_non_consecutive_dims(%arg0 : tensor<?x?xi32>) -> tensor<?x8x?x4xi32> {
+  %c0 = arith.constant 0 : index
+  %c2 = arith.constant 2 : index
+  %0 = tensor.expand_shape %arg0 [[0, 1], [2, 3]] : tensor<?x?xi32> into tensor<?x4x?x8xi32>
+  %d0 = tensor.dim %0, %c0 : tensor<?x4x?x8xi32>
+  %d1 = tensor.dim %0, %c2 : tensor<?x4x?x8xi32>
+  %init = linalg.init_tensor [%d1, 8, %d0, 4] : tensor<?x8x?x4xi32>
+  %1 = linalg.generic {
+      indexing_maps = [#map0, #map1],
+      iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+      ins(%0 : tensor<?x4x?x8xi32>) outs(%init : tensor<?x8x?x4xi32>) {
+    ^bb0(%b0 : i32, %b1 : i32):
+      %2 = linalg.index 0 : index
+      %3 = linalg.index 1 : index
+      %4 = linalg.index 2 : index
+      %5 = linalg.index 3 : index
+      %6 = arith.addi %2, %3 : index
+      %7 = arith.addi %6, %4 : index
+      %8 = arith.addi %7, %5 : index
+      %9 = arith.index_cast %8 : index to i32
+      linalg.yield %9: i32
+    } -> tensor<?x8x?x4xi32>
+  return %1 : tensor<?x8x?x4xi32>
+}
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d1, d0)>
+//      CHECK: func @fold_non_consecutive_dims(
+// CHECK-SAME:     %[[ARG0:.+]]: tensor<?x?xi32>)
+//  CHECK-DAG:   %[[C4:.+]] = arith.constant 4 : index
+//  CHECK-DAG:   %[[C8:.+]] = arith.constant 8 : index
+//      CHECK:   %[[INIT:.+]] = linalg.init_tensor
+//      CHECK:   %[[COLLAPSE_INIT:.+]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1], [2, 3]{{\]}}
+//      CHECK:   %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME:       indexing_maps = [#[[MAP0]], #[[MAP1]]]
+// CHECK-SAME:       iterator_types = ["parallel", "parallel"]
+// CHECK-SAME:       ins(%[[ARG0]] :
+// CHECK-SAME:       outs(%[[COLLAPSE_INIT]] :
+// CHECK-NEXT:   ^bb{{[0-9]}}
+//      CHECK:       %[[ID0:.+]] = linalg.index 0
+//  CHECK-DAG:       %[[T0:.+]] = arith.remui %[[ID0]], %[[C4]]
+//  CHECK-DAG:       %[[T1:.+]] = arith.divui %[[ID0]], %[[C4]]
+//      CHECK:       %[[ID1:.+]] = linalg.index 1
+//  CHECK-DAG:       %[[T2:.+]] = arith.remui %[[ID1]], %[[C8]]
+//  CHECK-DAG:       %[[T3:.+]] = arith.divui %[[ID1]], %[[C8]]
+//  CHECK-DAG:       %[[T4:.+]] = arith.addi %[[T1]], %[[T2]]
+//  CHECK-DAG:       %[[T5:.+]] = arith.addi %[[T4]], %[[T0]]
+//  CHECK-DAG:       %[[T6:.+]] = arith.addi %[[T5]], %[[T3]]
+//  CHECK-DAG:       %[[T7:.+]] = arith.index_cast %[[T6]]
+//      CHECK:       linalg.yield %[[T7]]
+//      CHECK:   %[[EXPAND_GENERIC:.+]] = tensor.expand_shape %[[GENERIC]] {{\[}}[0, 1], [2, 3]{{\]}}
+//      CHECK:   return %[[EXPAND_GENERIC]]
+
+// -----
+
+// None of the folded iteration space dims are contiguous reduction dimensions.
+// So no change in the code.
+#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3, d1)>
+#map1 = affine_map<(d0, d1, d2, d3) -> ()>
+func @no_fold_non_consecutive_reduction_dims(%arg0 : tensor<?x?xi32>) -> tensor<i32> {
+  %c0 = arith.constant 0 : index
+  %c2 = arith.constant 2 : index
+  %0 = tensor.expand_shape %arg0 [[0, 1], [2, 3]] : tensor<?x?xi32> into tensor<?x4x?x8xi32>
+  %init = linalg.init_tensor [] : tensor<i32>
+  %1 = linalg.generic {
+      indexing_maps = [#map0, #map1],
+      iterator_types = ["reduction", "reduction", "reduction", "reduction"]}
+      ins(%0 : tensor<?x4x?x8xi32>) outs(%init : tensor<i32>) {
+    ^bb0(%b0 : i32, %b1 : i32):
+      %2 = linalg.index 0 : index
+      %3 = linalg.index 1 : index
+      %4 = linalg.index 2 : index
+      %5 = linalg.index 3 : index
+      %6 = arith.addi %2, %3 : index
+      %7 = arith.addi %6, %4 : index
+      %8 = arith.addi %7, %5 : index
+      %9 = arith.index_cast %8 : index to i32
+      linalg.yield %9: i32
+    } -> tensor<i32>
+  return %1 : tensor<i32>
+}
+//      CHECK: func @no_fold_non_consecutive_reduction_dims(
+// CHECK-SAME:   %[[ARG0:.+]]: tensor<?x?xi32>)
+//      CHECK:   %[[EXPAND_ARG0:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2, 3]{{\]}}
+//      CHECK:   %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME:       ins(%[[EXPAND_ARG0]] :
+//      CHECK:   return %[[GENERIC]]