[Mlir-commits] [mlir] [mlir][linalg] Allow fusing reshapes with parallel operands (PR #130148)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Mar 6 10:47:17 PST 2025


llvmbot wrote:



@llvm/pr-subscribers-mlir-linalg

Author: Ian Wood (IanWood1)

<details>
<summary>Changes</summary>

Removes the condition that required the reshaped operand not to be indexed by any reduction iterator. This leaves the decision to the reshape fusion control function, which allows finer-grained control. For example, users could allow fusing reshapes that expand the M/N dims of a matmul but not the K dim (or preserve the current behavior by not fusing such reshapes at all).
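To illustrate the kind of policy a control function could apply, here is a minimal, standalone sketch. It is a simplified model rather than MLIR code: `IterKind`, `expandsOnlyParallelDims`, and the plain-vector encoding of the indexing map and reassociation are all hypothetical names and assumptions made for illustration, and a real control function would presumably query the ops involved instead.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical stand-in for a loop's iterator kind; not the MLIR enum.
enum class IterKind { Parallel, Reduction };

// Simplified model of a fusion policy (an assumption, not MLIR API).
//   iterKinds:     iterator kind of each loop of the linalg op.
//   operandToLoop: for each dim of the reshaped operand, the loop that
//                  indexes it (the operand map is a projected permutation).
//   reassociation: for each operand dim, the group of expanded dims it
//                  maps to; a group with more than one entry is expanded.
// Returns true only if every expanded operand dim is indexed by a
// parallel loop. Malformed input is rejected.
bool expandsOnlyParallelDims(
    const std::vector<IterKind> &iterKinds,
    const std::vector<std::size_t> &operandToLoop,
    const std::vector<std::vector<std::size_t>> &reassociation) {
  if (reassociation.size() != operandToLoop.size())
    return false;
  for (std::size_t dim = 0; dim < reassociation.size(); ++dim) {
    // Dims that the reshape leaves untouched never block fusion.
    if (reassociation[dim].size() <= 1)
      continue;
    std::size_t loop = operandToLoop[dim];
    if (loop >= iterKinds.size() || iterKinds[loop] != IterKind::Parallel)
      return false;
  }
  return true;
}
```

For a matmul with loops (m, n, k) and an LHS indexed by (m, k), this policy would accept a reshape that expands only M and reject one that expands K.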

---
Full diff: https://github.com/llvm/llvm-project/pull/130148.diff


1 File Affected:

- (modified) mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp (+1-6) 


``````````diff
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index a45b5c43f5d33..337fd8f3a0ac1 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -566,7 +566,6 @@ static bool isFusableWithReshapeByDimExpansion(LinalgOp linalgOp,
   // - All the indexing maps for operands and results are projected
   //   permutations.
   // - The fused tensor is not a scalar.
-  // - All the loops for the reshaped operand are parallel loops.
   SmallVector<utils::IteratorType> iteratorTypes =
       linalgOp.getIteratorTypesArray();
   AffineMap operandMap = linalgOp.getMatchingIndexingMap(fusableOpOperand);
@@ -577,11 +576,7 @@ static bool isFusableWithReshapeByDimExpansion(LinalgOp linalgOp,
                             .getValue()
                             .isProjectedPermutation();
                       }) &&
-         operandMap.getNumResults() > 0 &&
-         llvm::all_of(operandMap.getResults(), [&](AffineExpr expr) {
-           return isParallelIterator(
-               iteratorTypes[cast<AffineDimExpr>(expr).getPosition()]);
-         });
+         operandMap.getNumResults() > 0;
 }
 
 namespace {

``````````

</details>


https://github.com/llvm/llvm-project/pull/130148

