[Mlir-commits] [mlir] 70b95d1 - [mlir][linalg] Retain Op Type of linalg ops in fuseWithReshapeByExpansion pattern (#129128)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Mar 3 11:20:33 PST 2025
Author: Nirvedh Meshram
Date: 2025-03-03T13:20:29-06:00
New Revision: 70b95d16645dfe1e8d76bdf94e791d74ad36e780
URL: https://github.com/llvm/llvm-project/commit/70b95d16645dfe1e8d76bdf94e791d74ad36e780
DIFF: https://github.com/llvm/llvm-project/commit/70b95d16645dfe1e8d76bdf94e791d74ad36e780.diff
LOG: [mlir][linalg] Retain Op Type of linalg ops in fuseWithReshapeByExpansion pattern (#129128)
This PR preserves linalg Op types for certain named ops such as Fill,
Copy, and Transpose instead of fusion always resulting in a generic Op.
---------
Signed-off-by: Nirvedh Meshram <nirvedh at gmail.com>
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
mlir/test/Dialect/Linalg/reshape_fusion.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index f4b6955823085..a45b5c43f5d33 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -815,6 +815,77 @@ validateDynamicDimExpansion(LinalgOp linalgOp,
return success();
}
+// Create an expanded transpose op.
+static Operation *
+createExpandedTransposeOp(PatternRewriter &rewriter, TransposeOp transposeOp,
+ SmallVector<ReassociationIndices> reassociation,
+ Value expandedInput, Value output) {
+ applyPermutationToVector(reassociation, transposeOp.getPermutation());
+ SmallVector<int64_t> newPerm;
+ for (auto reassoc : reassociation) {
+ for (auto dim : reassoc) {
+ newPerm.push_back(dim);
+ }
+ }
+ return rewriter.create<TransposeOp>(transposeOp.getLoc(), expandedInput,
+ output, newPerm);
+}
+
+// Create an expanded generic op.
+static Operation *createExpandedGenericOp(
+ PatternRewriter &rewriter, LinalgOp linalgOp, TypeRange resultTypes,
+ ArrayRef<Value> &expandedOpOperands, ArrayRef<Value> outputs,
+ ExpansionInfo &expansionInfo, ArrayRef<AffineMap> expandedOpIndexingMaps) {
+ // The iterator types of the expanded op are all parallel.
+ SmallVector<utils::IteratorType> iteratorTypes(
+ expansionInfo.getExpandedOpNumDims(), utils::IteratorType::parallel);
+
+ for (auto [i, type] : llvm::enumerate(linalgOp.getIteratorTypesArray()))
+ for (auto j : expansionInfo.getExpandedDims(i))
+ iteratorTypes[j] = type;
+
+ Operation *fused = rewriter.create<GenericOp>(
+ linalgOp.getLoc(), resultTypes, expandedOpOperands, outputs,
+ expandedOpIndexingMaps, iteratorTypes);
+
+ Region &fusedRegion = fused->getRegion(0);
+ Region &originalRegion = linalgOp->getRegion(0);
+ rewriter.cloneRegionBefore(originalRegion, fusedRegion, fusedRegion.begin());
+
+ // Update the index accesses after the expansion.
+ updateExpandedGenericOpRegion(rewriter, linalgOp.getLoc(), fusedRegion,
+ expansionInfo);
+
+ return fused;
+}
+
+// Create an expanded fused op that retains the name for certain ops
+// such as fill, copy and transpose and produce a generic op for
+// rest of linalg ops.
+static Operation *createExpandedOp(
+ PatternRewriter &rewriter, LinalgOp linalgOp, TypeRange resultTypes,
+ ArrayRef<Value> expandedOpOperands, ArrayRef<Value> outputs,
+ ArrayRef<AffineMap> expandedOpIndexingMaps, ExpansionInfo &expansionInfo,
+ SmallVector<ReassociationIndices> reassociation) {
+
+ return TypeSwitch<Operation *, Operation *>(linalgOp.getOperation())
+ .Case<TransposeOp>([&](TransposeOp transposeOp) {
+ return createExpandedTransposeOp(rewriter, transposeOp, reassociation,
+ expandedOpOperands[0], outputs[0]);
+ })
+ .Case<FillOp, CopyOp>([&](Operation *op) {
+ return clone(rewriter, linalgOp, resultTypes,
+ llvm::to_vector(llvm::concat<Value>(
+ llvm::to_vector(expandedOpOperands),
+ llvm::to_vector(outputs))));
+ })
+ .Default([&](Operation *op) {
+ return createExpandedGenericOp(rewriter, linalgOp, resultTypes,
+ expandedOpOperands, outputs,
+ expansionInfo, expandedOpIndexingMaps);
+ });
+}
+
/// Implements the fusion of a tensor.collapse_shape or a tensor.expand_shape op
/// and a generic op as explained in `isFusableWithReshapeByExpansion`. Assumes
/// that those conditions have been satisfied.
@@ -919,25 +990,13 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
}
}
- // The iterator types of the expanded op are all parallel.
- SmallVector<utils::IteratorType> iteratorTypes(
- expansionInfo.getExpandedOpNumDims(), utils::IteratorType::parallel);
- for (auto [i, type] : llvm::enumerate(linalgOp.getIteratorTypesArray()))
- for (auto j : expansionInfo.getExpandedDims(i))
- iteratorTypes[j] = type;
-
TypeRange resultTypes = ValueRange(outputs).getTypes();
- auto fusedOp =
- rewriter.create<GenericOp>(linalgOp.getLoc(), resultTypes,
- /*inputs=*/expandedOpOperands, outputs,
- expandedOpIndexingMaps, iteratorTypes);
- Region &fusedRegion = fusedOp->getRegion(0);
- Region &originalRegion = linalgOp->getRegion(0);
- rewriter.cloneRegionBefore(originalRegion, fusedRegion, fusedRegion.begin());
-
- // Update the index accesses after the expansion.
- updateExpandedGenericOpRegion(rewriter, loc, fusedRegion, expansionInfo);
-
+ SmallVector<ReassociationIndices> reassociationBeforeExpansion =
+ isExpanding ? expandingReshapeOp.getReassociationIndices()
+ : collapsingReshapeOp.getReassociationIndices();
+ Operation *fusedOp = createExpandedOp(
+ rewriter, linalgOp, resultTypes, expandedOpOperands, outputs,
+ expandedOpIndexingMaps, expansionInfo, reassociationBeforeExpansion);
// Reshape the result values to their original shape if this is a collapsing
// reshape folded into its consumer.
SmallVector<Value> resultVals;
diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
index ef853e4d662a7..4da9c0851ac70 100644
--- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -783,9 +783,6 @@ func.func @linalg_add_reshape_consumer_fusion(%arg0 : tensor<?x?xf32>,
// -----
-#map0 = affine_map<(d0, d1, d2) -> (d2, d0)>
-#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-#map2 = affine_map<(d0, d1, d2) -> (d0, d2)>
func.func @linalg_add_reshape_producer_fusion(%arg0 : tensor<?x7x?x8xf32>,
%arg1 : tensor<?x?xf32>,
%arg2 : tensor<?x?xf32>) ->
@@ -829,6 +826,73 @@ func.func @linalg_add_reshape_producer_fusion(%arg0 : tensor<?x7x?x8xf32>,
// -----
+func.func @linalg_copy_reshape_producer_fusion(%arg0 : tensor<?x7x?x8xf32>,
+ %arg1 : tensor<?x?xf32>) ->
+ tensor<?x?xf32>
+{
+ %0 = tensor.collapse_shape %arg0 [[0, 1], [2, 3]] :
+ tensor<?x7x?x8xf32> into tensor<?x?xf32>
+ %1 = linalg.copy ins(%0 : tensor<?x?xf32>)
+ outs(%arg1 : tensor<?x?xf32>) -> tensor<?x?xf32>
+ return %1 : tensor<?x?xf32>
+}
+
+// CHECK: func @linalg_copy_reshape_producer_fusion
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x7x?x8xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK: %[[C8:.+]] = arith.constant 8 : index
+// CHECK: %[[C7:.+]] = arith.constant 7 : index
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
+// CHECK: %[[DIM_0:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
+// CHECK: %[[VAL_0:.+]] = arith.divsi %[[DIM]], %[[C7]] : index
+// CHECK: %[[VAL_1:.+]] = arith.divsi %[[DIM_0]], %[[C8]] : index
+// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[VAL_0]], 7, %[[VAL_1]], 8] : tensor<?x?xf32> into tensor<?x7x?x8xf32>
+// CHECK: %[[T2:.+]] = linalg.copy
+// CHECK-SAME: ins(%[[ARG0]] : tensor<?x7x?x8xf32>)
+// CHECK-SAME: outs(%[[T1]] : tensor<?x7x?x8xf32>)
+// CHECK: %[[T3:.+]] = tensor.collapse_shape %[[T2]]
+// CHECK-SAME: [0, 1], [2, 3]
+// CHECK-SAME: tensor<?x7x?x8xf32> into tensor<?x?xf32>
+// CHECK: return %[[T3]]
+
+// -----
+
+func.func @linalg_transpose_reshape_producer_fusion(%arg0 : tensor<?x7x?x8xf32>,
+ %arg1 : tensor<?x?xf32>) ->
+ tensor<?x?xf32>
+{
+ %0 = tensor.collapse_shape %arg0 [[0, 1], [2, 3]] :
+ tensor<?x7x?x8xf32> into tensor<?x?xf32>
+ %1 = linalg.transpose ins(%0 : tensor<?x?xf32>)
+ outs(%arg1 : tensor<?x?xf32>) permutation = [1, 0]
+ return %1 : tensor<?x?xf32>
+}
+
+// CHECK: func @linalg_transpose_reshape_producer_fusion
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x7x?x8xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-DAG: %[[C8:.+]] = arith.constant 8 : index
+// CHECK-DAG: %[[C7:.+]] = arith.constant 7 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[DIM:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
+// CHECK-DAG: %[[DIM_0:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
+// CHECK-DAG: %[[VAL_0:.+]] = arith.divsi %[[DIM_0]], %[[C7]] : index
+// CHECK-DAG: %[[VAL_1:.+]] = arith.divsi %[[DIM]], %[[C8]] : index
+// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[VAL_1]], 8, %[[VAL_0]], 7] : tensor<?x?xf32> into tensor<?x8x?x7xf32>
+// CHECK: %[[T2:.+]] = linalg.transpose
+// CHECK-SAME: ins(%[[ARG0]] : tensor<?x7x?x8xf32>)
+// CHECK-SAME: outs(%[[T1]] : tensor<?x8x?x7xf32>)
+// CHECK-SAME: permutation = [2, 3, 0, 1]
+// CHECK: %[[T3:.+]] = tensor.collapse_shape %[[T2]]
+// CHECK-SAME: [0, 1], [2, 3]
+// CHECK-SAME: tensor<?x8x?x7xf32> into tensor<?x?xf32>
+// CHECK: return %[[T3]]
+
+// -----
+
func.func @fuse_by_expanding_pad(%arg0 : tensor<2x3x4x5x6x7x8x9xi32>) -> tensor<8x12x17x336x14xi32> {
%collapse = tensor.collapse_shape %arg0 [[0], [1, 2], [3], [4, 5, 6], [7]] : tensor<2x3x4x5x6x7x8x9xi32> into tensor<2x12x5x336x9xi32>
%cst = arith.constant 0 : i32
More information about the Mlir-commits
mailing list