[Mlir-commits] [mlir] [mlir][Linalg] Fix linalg.generic iteration domain collapse for dynamic dims (PR #118208)
Kunwar Grover
llvmlistbot at llvm.org
Wed Feb 19 11:45:00 PST 2025
https://github.com/Groverkss updated https://github.com/llvm/llvm-project/pull/118208
From 593b5b88cb8c20ccdd4bd367135f6cb062b98a1a Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Sun, 1 Dec 2024 10:11:07 +0000
Subject: [PATCH 1/3] [mlir][Linalg] Fix linalg.generic iteration domain
collapse for dynamic dims
---
.../Linalg/Transforms/ElementwiseOpFusion.cpp | 23 ++++---
.../fuse-with-reshape-by-collapsing.mlir | 63 +++++++++++++------
.../Dialect/Linalg/fusion-push-reshape.mlir | 7 +--
3 files changed, 62 insertions(+), 31 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 60cae77644291..3342c272f0f10 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1550,7 +1550,7 @@ static Value getCollapsedOpOperand(Location loc, LinalgOp op,
/// value in the collapsed operation.
void generateCollapsedIndexingRegion(Location loc, Block *block,
const CollapsingInfo &collapsingInfo,
- ValueRange loopRange,
+ ArrayRef<OpFoldResult> loopRange,
RewriterBase &rewriter) {
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPointToStart(block);
@@ -1572,10 +1572,12 @@ void generateCollapsedIndexingRegion(Location loc, Block *block,
Value newIndexVal =
rewriter.create<linalg::IndexOp>(loc, foldedDims.index());
for (auto dim : llvm::reverse(foldedDimsRef.drop_front())) {
+ Value loopDim =
+ getValueOrCreateConstantIndexOp(rewriter, loc, loopRange[dim]);
indexReplacementVals[dim] =
- rewriter.create<arith::RemSIOp>(loc, newIndexVal, loopRange[dim]);
+ rewriter.createOrFold<arith::RemSIOp>(loc, newIndexVal, loopDim);
newIndexVal =
- rewriter.create<arith::DivSIOp>(loc, newIndexVal, loopRange[dim]);
+ rewriter.createOrFold<arith::DivSIOp>(loc, newIndexVal, loopDim);
}
indexReplacementVals[foldedDims.value().front()] = newIndexVal;
}
@@ -1722,14 +1724,13 @@ FailureOr<CollapseResult> mlir::linalg::collapseOpIterationDims(
LinalgOp collapsedOp = createCollapsedOp(op, collapsingInfo, rewriter);
Location loc = op->getLoc();
+ SmallVector<OpFoldResult> loopBound =
+ llvm::map_to_vector(loopRanges, [&](Range range) { return range.size; });
+
if (collapsedOp.hasIndexSemantics()) {
// Collect the loop range of the generic op.
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(collapsedOp);
- SmallVector<Value> loopBound =
- llvm::map_to_vector(loopRanges, [&](Range range) {
- return getValueOrCreateConstantIndexOp(rewriter, loc, range.size);
- });
generateCollapsedIndexingRegion(loc, &collapsedOp->getRegion(0).front(),
collapsingInfo, loopBound, rewriter);
}
@@ -1747,15 +1748,19 @@ FailureOr<CollapseResult> mlir::linalg::collapseOpIterationDims(
op.getIndexingMapMatchingResult(originalResult.value());
SmallVector<ReassociationIndices> reassociation =
getOperandReassociation(indexingMap, collapsingInfo);
+ SmallVector<OpFoldResult> resultShape =
+ applyPermutationMap(indexingMap, ArrayRef(loopBound));
Value result;
if (isa<MemRefType>(collapsedOpResult.getType())) {
MemRefType expandShapeResultType = MemRefType::get(
originalResultType.getShape(), originalResultType.getElementType());
result = rewriter.create<memref::ExpandShapeOp>(
- loc, expandShapeResultType, collapsedOpResult, reassociation);
+ loc, expandShapeResultType, collapsedOpResult, reassociation,
+ resultShape);
} else {
result = rewriter.create<tensor::ExpandShapeOp>(
- loc, originalResultType, collapsedOpResult, reassociation);
+ loc, originalResultType, collapsedOpResult, reassociation,
+ resultShape);
}
results.push_back(result);
} else {
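
A note on the indexing-region hunk above: the folded `linalg.index` values are
delinearized with a remsi/divsi chain over the inner loop bounds, and switching
to `createOrFold` with `OpFoldResult` bounds lets trivial cases fold away (e.g.
a unit bound folds the remainder to 0 and the quotient to the index itself). A
minimal sketch of the generated region for two dims (d0, d1) folded into one
loop, with %b1 as an illustrative inner bound (names not from the patch):

    // collapsed index for the folded (d0, d1) pair
    %i  = linalg.index 0 : index
    // innermost folded dim: remainder w.r.t. its loop bound
    %d1 = arith.remsi %i, %b1 : index
    // remaining outer dim: quotient
    %d0 = arith.divsi %i, %b1 : index
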
diff --git a/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir b/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
index 7db997cd4c0b5..89734e7542801 100644
--- a/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
+++ b/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
@@ -225,6 +225,38 @@ func.func @fuse_by_collapsing_dynamic(%arg0 : tensor<?x?x?x?x?xi32>,
// -----
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+func.func @fuse_by_collapsing_dynamic_2(%arg0 : tensor<?xf32>, %sz0: index, %sz1: index) -> tensor<?x?xf32> {
+ %0 = tensor.expand_shape %arg0 [[0, 1]] output_shape [%sz0, %sz1] : tensor<?xf32> into tensor<?x?xf32>
+ %init = tensor.empty(%sz1, %sz0) : tensor<?x?xf32>
+ %1 = linalg.generic {
+ indexing_maps = [#map0, #map0],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%0 : tensor<?x?xf32>)
+ outs(%init : tensor<?x?xf32>) {
+ ^bb0(%b0 : f32, %b1 : f32):
+ %out = arith.negf %b0 : f32
+ linalg.yield %out : f32
+ } -> tensor<?x?xf32>
+ return %1 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: func @fuse_by_collapsing_dynamic_2
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?xf32>
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]]
+// CHECK-DAG: %[[DIM0:.+]] = tensor.dim %[[EXPANDED]], %[[C0]]
+// CHECK-DAG: %[[DIM1:.+]] = tensor.dim %[[EXPANDED]], %[[C1]]
+// CHECK: %[[OUT:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[ARG0]] : tensor<?xf32>)
+// CHECK-SAME: outs(%{{.*}} : tensor<?xf32>)
+// CHECK: %[[EXPANDED_1:.+]] = tensor.expand_shape %[[OUT]]
+// CHECK-SAME: output_shape [%[[DIM0]], %[[DIM1]]]
+// CHECK: return %[[EXPANDED_1]]
+
+// -----
+
#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3)>
func.func @fuse_reductions(%arg0 : tensor<2x?x5xf32>, %arg1 : tensor<2x5xf32>, %sz0: index) -> tensor<2x5xf32> {
@@ -425,10 +457,11 @@ func.func @fuse_only_one_reassociation(%arg0 : tensor<?x?xf32>, %arg1 : tensor<4
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK: func @fuse_only_one_reassociation
// CHECK-SAME: (%[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<4x?x?x8xf32>, %[[SZ0:.+]]: index, %[[SZ1:.+]]: index)
-// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
-// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[EXPAND_ARG0:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2, 3]{{\]}} output_shape [%[[SZ0]], 4, %[[SZ1]], 8]
+// CHECK-DAG: %[[DIM:.+]] = tensor.dim %[[EXPAND_ARG0]], %[[C0]] : tensor<?x4x?x8xf32>
+// CHECK-DAG: %[[DIM_2:.+]] = tensor.dim %[[EXPAND_ARG0]], %[[C2]] : tensor<?x4x?x8xf32>
// CHECK-DAG: %[[COLLAPSE_ARG0:.+]] = tensor.collapse_shape %[[EXPAND_ARG0]] {{\[}}[0], [1], [2, 3]{{\]}}
// CHECK-DAG: %[[COLLAPSE_ARG1_0:.+]] = tensor.collapse_shape %[[ARG1]] {{\[}}[0], [1], [2, 3]{{\]}}
// CHECK-DAG: %[[COLLAPSE_ARG1_1:.+]] = tensor.collapse_shape %[[ARG1]] {{\[}}[0], [1], [2, 3]{{\]}}
@@ -437,10 +470,7 @@ func.func @fuse_only_one_reassociation(%arg0 : tensor<?x?xf32>, %arg1 : tensor<4
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
// CHECK-SAME: ins(%[[COLLAPSE_ARG0]], %[[COLLAPSE_ARG1_0]] :
// CHECK-SAME: outs(%[[COLLAPSE_ARG1_1]] :
-// CHECK: %[[DIM:.+]] = tensor.dim %[[GENERIC]], %[[C1]] : tensor<4x?x?xf32>
-// CHECK: %[[DIM_2:.+]] = tensor.dim %[[GENERIC]], %[[C2]] : tensor<4x?x?xf32>
-// CHECK: %[[VAL_1:.+]] = arith.divsi %[[DIM_2]], %[[C8]] : index
-// CHECK: %[[EXPANDED_3:.+]] = tensor.expand_shape %[[GENERIC]] {{\[\[}}0], [1], [2, 3]] output_shape [4, %[[DIM]], %[[VAL_1]], 8] : tensor<4x?x?xf32> into tensor<4x?x?x8xf32>
+// CHECK: %[[EXPANDED_3:.+]] = tensor.expand_shape %[[GENERIC]] {{\[\[}}0], [1], [2, 3]] output_shape [4, %[[DIM]], %[[DIM_2]], 8] : tensor<4x?x?xf32> into tensor<4x?x?x8xf32>
// CHECK: return %[[EXPANDED_3]]
// -----
@@ -475,15 +505,16 @@ func.func @fold_non_consecutive_dims(%arg0 : tensor<?x?xi32>, %sz0: index, %sz1:
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d1, d0)>
// CHECK: func @fold_non_consecutive_dims(
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xi32>, %[[SZ0:.+]]: index, %[[SZ1:.+]]: index)
-// CHECK: %[[C1:.+]] = arith.constant 1 : index
-// CHECK: %[[C4:.+]] = arith.constant 4 : index
-// CHECK: %[[C8:.+]] = arith.constant 8 : index
-// CHECK: %[[C0:.+]] = arith.constant 0 : index
-// CHECK: %[[C2:.+]] = arith.constant 2 : index
+// CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index
+// CHECK-DAG: %[[C8:.+]] = arith.constant 8 : index
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[SZ0]], 4, %[[SZ1]], 8] : tensor<?x?xi32> into tensor<?x4x?x8xi32>
-// CHECK: %[[DIM:.+]] = tensor.dim %[[EXPANDED]], %[[C0]]
-// CHECK: %[[DIM_0:.+]] = tensor.dim %[[EXPANDED]], %[[C2]]
+// CHECK-DAG: %[[DIM:.+]] = tensor.dim %[[EXPANDED]], %[[C0]]
+// CHECK-DAG: %[[DIM_0:.+]] = tensor.dim %[[EXPANDED]], %[[C2]]
// CHECK: %[[INIT:.+]] = tensor.empty(%[[DIM_0]], %[[DIM]])
+// CHECK-DAG: %[[DIM_1:.+]] = tensor.dim %[[EXPANDED]], %[[C0]]
+// CHECK-DAG: %[[DIM_2:.+]] = tensor.dim %[[EXPANDED]], %[[C2]]
// CHECK: %[[COLLAPSE_INIT:.+]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1], [2, 3]{{\]}}
// CHECK: %[[GENERIC:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
@@ -502,11 +533,7 @@ func.func @fold_non_consecutive_dims(%arg0 : tensor<?x?xi32>, %sz0: index, %sz1:
// CHECK-DAG: %[[T6:.+]] = arith.addi %[[T5]], %[[T3]]
// CHECK-DAG: %[[T7:.+]] = arith.index_cast %[[T6]]
// CHECK: linalg.yield %[[T7]]
-// CHECK: %[[DIM_1:.+]] = tensor.dim %[[GENERIC]], %[[C0]] : tensor<?x?xi32>
-// CHECK: %[[DIM_2:.+]] = tensor.dim %[[GENERIC]], %[[C1]] : tensor<?x?xi32>
-// CHECK: %[[VAL_2:.+]] = arith.divsi %[[DIM_1]], %[[C8]] : index
-// CHECK: %[[VAL_3:.+]] = arith.divsi %[[DIM_2]], %[[C4]] : index
-// CHECK: %[[EXPANDED_3:.+]] = tensor.expand_shape %[[GENERIC]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[VAL_2]], 8, %[[VAL_3]], 4] : tensor<?x?xi32> into tensor<?x8x?x4xi32>
+// CHECK: %[[EXPANDED_3:.+]] = tensor.expand_shape %[[GENERIC]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[DIM_2]], 8, %[[DIM_1]], 4] : tensor<?x?xi32> into tensor<?x8x?x4xi32>
// CHECK: return %[[EXPANDED_3]]
// -----
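
The fold_non_consecutive_dims update shows the substance of the fix: the
expansion sizes are now read off the iteration domain (tensor.dim on the
expanded operand, hoisted above the generic) and permuted through the result's
indexing map, instead of being reconstructed from the collapsed result with
arith.divsi. Schematically, for one dynamic size (SSA names illustrative):

    // before: size re-derived from the collapsed result
    %d = tensor.dim %generic, %c0 : tensor<?x?xi32>
    %s = arith.divsi %d, %c8 : index
    // after: size taken directly from the iteration domain
    %s = tensor.dim %expanded, %c2 : tensor<?x4x?x8xi32>

The divsi reconstruction only works when at most one folded size per group is
dynamic; the new fuse_by_collapsing_dynamic_2 test above covers the case where
several are.
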
diff --git a/mlir/test/Dialect/Linalg/fusion-push-reshape.mlir b/mlir/test/Dialect/Linalg/fusion-push-reshape.mlir
index 7acbd843cd1e7..fd3c321722508 100644
--- a/mlir/test/Dialect/Linalg/fusion-push-reshape.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-push-reshape.mlir
@@ -5,15 +5,14 @@
// CHECK-LABEL: func @reshape
// CHECK-SAME: (%[[A:.*]]: tensor<?x16xf32>, %[[B:.*]]: tensor<16xf32>, %[[INIT:.*]]: tensor<?x112x16xf32>, %[[SZ0:.*]]: index)
-// CHECK: %[[C112:.*]] = arith.constant 112 : index
// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[A]]
+// CHECK: %[[DIM:.*]] = tensor.dim %[[EXPANDED]], %[[C0]]
// CHECK: %[[RI:.*]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1], [2]] : tensor<?x112x16xf32> into tensor<?x16xf32>
// CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP3]], #[[$MAP2]]],
// CHECK-SAME: iterator_types = ["parallel", "parallel"]}
// CHECK-SAME: ins(%[[A]], %[[B]] : tensor<?x16xf32>, tensor<16xf32>) outs(%[[RI]] : tensor<?x16xf32>)
-// CHECK: %[[DIM:.*]] = tensor.dim %[[R]], %[[C0]] : tensor<?x16xf32>
-// CHECK: %[[VAL_1:.*]] = arith.divsi %[[DIM]], %[[C112]] : index
-// CHECK: %[[RR:.*]] = tensor.expand_shape %[[R]] {{\[\[}}0, 1], [2]] output_shape [%[[VAL_1]], 112, 16] : tensor<?x16xf32> into tensor<?x112x16xf32>
+// CHECK: %[[RR:.*]] = tensor.expand_shape %[[R]] {{\[\[}}0, 1], [2]] output_shape [%[[DIM]], 112, 16] : tensor<?x16xf32> into tensor<?x112x16xf32>
// CHECK: return %[[RR]] : tensor<?x112x16xf32>
func.func @reshape(%A: tensor<?x16xf32>, %B: tensor<16xf32>, %init: tensor<?x112x16xf32>, %sz0: index) -> tensor<?x112x16xf32> {
%0 = tensor.expand_shape %A [[0, 1], [2]] output_shape [%sz0, 112, 16]
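
For completeness, the same explicit-output_shape treatment applies on the
memref branch of the C++ hunk: memref.expand_shape is now built with the
projected loop bounds as well. An illustrative instance (not from the tests):

    %e = memref.expand_shape %res [[0, 1], [2]] output_shape [%d0, 112, 16]
        : memref<?x16xf32> into memref<?x112x16xf32>
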
From 533294d12b84d4bd8d1b30938b936629de1333e4 Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Sun, 9 Feb 2025 00:40:42 +0000
Subject: [PATCH 2/3] Address comments
---
.../Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp | 10 ++++++----
1 file changed, 6 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 3342c272f0f10..1ce29e8206d1f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1548,10 +1548,9 @@ static Value getCollapsedOpOperand(Location loc, LinalgOp op,
/// Modify the `linalg.index` operations in the original generic op, to its
/// value in the collapsed operation.
-void generateCollapsedIndexingRegion(Location loc, Block *block,
- const CollapsingInfo &collapsingInfo,
- ArrayRef<OpFoldResult> loopRange,
- RewriterBase &rewriter) {
+static void generateCollapsedIndexingRegion(
+ Location loc, Block *block, const CollapsingInfo &collapsingInfo,
+ ArrayRef<OpFoldResult> loopRange, RewriterBase &rewriter) {
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPointToStart(block);
@@ -1748,6 +1747,9 @@ FailureOr<CollapseResult> mlir::linalg::collapseOpIterationDims(
op.getIndexingMapMatchingResult(originalResult.value());
SmallVector<ReassociationIndices> reassociation =
getOperandReassociation(indexingMap, collapsingInfo);
+ assert(
+ indexingMap.isProjectedPermutation() &&
+ "Expected indexing map to be a projected permutation for collapsing");
SmallVector<OpFoldResult> resultShape =
applyPermutationMap(indexingMap, ArrayRef(loopBound));
Value result;
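
On the assert added here: applyPermutationMap requires the result's indexing
map to be a projected permutation, i.e. every map result is a distinct
iteration dim (dims may be dropped, none may be combined). For example:

    // projected permutation: each result is a distinct dim
    affine_map<(d0, d1, d2) -> (d2, d0)>
    // not a projected permutation: results combine dims
    affine_map<(d0, d1) -> (d0 + d1)>

So the assert documents an existing invariant of the collapsing pattern rather
than imposing a new restriction.
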
From 5e764ec5b68700f4b31f59f4456478c8e5e20a5c Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Thu, 20 Feb 2025 01:14:52 +0530
Subject: [PATCH 3/3] Update
mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
Co-authored-by: Jakub Kuderski <kubakuderski at gmail.com>
---
mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 1ce29e8206d1f..f4b6955823085 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1724,7 +1724,7 @@ FailureOr<CollapseResult> mlir::linalg::collapseOpIterationDims(
Location loc = op->getLoc();
SmallVector<OpFoldResult> loopBound =
- llvm::map_to_vector(loopRanges, [&](Range range) { return range.size; });
+ llvm::map_to_vector(loopRanges, [](Range range) { return range.size; });
if (collapsedOp.hasIndexSemantics()) {
// Collect the loop range of the generic op.