[Mlir-commits] [mlir] [mlir][Linalg] Fix linalg.generic iteration domain collapse for dynamic dims (PR #118208)

Kunwar Grover llvmlistbot at llvm.org
Sat Feb 8 16:41:17 PST 2025


https://github.com/Groverkss updated https://github.com/llvm/llvm-project/pull/118208

From 593b5b88cb8c20ccdd4bd367135f6cb062b98a1a Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Sun, 1 Dec 2024 10:11:07 +0000
Subject: [PATCH 1/2] [mlir][Linalg] Fix linalg.generic iteration domain
 collapse for dynamic dims

---
 .../Linalg/Transforms/ElementwiseOpFusion.cpp | 23 ++++---
 .../fuse-with-reshape-by-collapsing.mlir      | 63 +++++++++++++------
 .../Dialect/Linalg/fusion-push-reshape.mlir   |  7 +--
 3 files changed, 62 insertions(+), 31 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 60cae776442915a..3342c272f0f1062 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1550,7 +1550,7 @@ static Value getCollapsedOpOperand(Location loc, LinalgOp op,
 /// value in the collapsed operation.
 void generateCollapsedIndexingRegion(Location loc, Block *block,
                                      const CollapsingInfo &collapsingInfo,
-                                     ValueRange loopRange,
+                                     ArrayRef<OpFoldResult> loopRange,
                                      RewriterBase &rewriter) {
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPointToStart(block);
@@ -1572,10 +1572,12 @@ void generateCollapsedIndexingRegion(Location loc, Block *block,
     Value newIndexVal =
         rewriter.create<linalg::IndexOp>(loc, foldedDims.index());
     for (auto dim : llvm::reverse(foldedDimsRef.drop_front())) {
+      Value loopDim =
+          getValueOrCreateConstantIndexOp(rewriter, loc, loopRange[dim]);
       indexReplacementVals[dim] =
-          rewriter.create<arith::RemSIOp>(loc, newIndexVal, loopRange[dim]);
+          rewriter.createOrFold<arith::RemSIOp>(loc, newIndexVal, loopDim);
       newIndexVal =
-          rewriter.create<arith::DivSIOp>(loc, newIndexVal, loopRange[dim]);
+          rewriter.createOrFold<arith::DivSIOp>(loc, newIndexVal, loopDim);
     }
     indexReplacementVals[foldedDims.value().front()] = newIndexVal;
   }
@@ -1722,14 +1724,13 @@ FailureOr<CollapseResult> mlir::linalg::collapseOpIterationDims(
   LinalgOp collapsedOp = createCollapsedOp(op, collapsingInfo, rewriter);
 
   Location loc = op->getLoc();
+  SmallVector<OpFoldResult> loopBound =
+      llvm::map_to_vector(loopRanges, [&](Range range) { return range.size; });
+
   if (collapsedOp.hasIndexSemantics()) {
     // Collect the loop range of the generic op.
     OpBuilder::InsertionGuard g(rewriter);
     rewriter.setInsertionPoint(collapsedOp);
-    SmallVector<Value> loopBound =
-        llvm::map_to_vector(loopRanges, [&](Range range) {
-          return getValueOrCreateConstantIndexOp(rewriter, loc, range.size);
-        });
     generateCollapsedIndexingRegion(loc, &collapsedOp->getRegion(0).front(),
                                     collapsingInfo, loopBound, rewriter);
   }
@@ -1747,15 +1748,19 @@ FailureOr<CollapseResult> mlir::linalg::collapseOpIterationDims(
           op.getIndexingMapMatchingResult(originalResult.value());
       SmallVector<ReassociationIndices> reassociation =
           getOperandReassociation(indexingMap, collapsingInfo);
+      SmallVector<OpFoldResult> resultShape =
+          applyPermutationMap(indexingMap, ArrayRef(loopBound));
       Value result;
       if (isa<MemRefType>(collapsedOpResult.getType())) {
         MemRefType expandShapeResultType = MemRefType::get(
             originalResultType.getShape(), originalResultType.getElementType());
         result = rewriter.create<memref::ExpandShapeOp>(
-            loc, expandShapeResultType, collapsedOpResult, reassociation);
+            loc, expandShapeResultType, collapsedOpResult, reassociation,
+            resultShape);
       } else {
         result = rewriter.create<tensor::ExpandShapeOp>(
-            loc, originalResultType, collapsedOpResult, reassociation);
+            loc, originalResultType, collapsedOpResult, reassociation,
+            resultShape);
       }
       results.push_back(result);
     } else {
diff --git a/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir b/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
index 7db997cd4c0b5f8..89734e7542801d0 100644
--- a/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
+++ b/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
@@ -225,6 +225,38 @@ func.func @fuse_by_collapsing_dynamic(%arg0 : tensor<?x?x?x?x?xi32>,
 
 // -----
 
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+func.func @fuse_by_collapsing_dynamic_2(%arg0 : tensor<?xf32>, %sz0: index, %sz1: index) -> tensor<?x?xf32> {
+  %0 = tensor.expand_shape %arg0 [[0, 1]] output_shape [%sz0, %sz1] : tensor<?xf32> into tensor<?x?xf32>
+  %init = tensor.empty(%sz1, %sz0) : tensor<?x?xf32>
+  %1 = linalg.generic {
+      indexing_maps = [#map0, #map0],
+      iterator_types = ["parallel", "parallel"]}
+      ins(%0 : tensor<?x?xf32>) 
+      outs(%init : tensor<?x?xf32>) {
+        ^bb0(%b0 : f32, %b1 : f32):
+          %out = arith.negf %b0 : f32
+          linalg.yield %out : f32
+      } -> tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: func @fuse_by_collapsing_dynamic_2
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?xf32>
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK:     %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]]
+// CHECK-DAG: %[[DIM0:.+]] = tensor.dim %[[EXPANDED]], %[[C0]]
+// CHECK-DAG: %[[DIM1:.+]] = tensor.dim %[[EXPANDED]], %[[C1]]
+// CHECK:     %[[OUT:.+]] = linalg.generic
+// CHECK-SAME:   ins(%[[ARG0]] : tensor<?xf32>)
+// CHECK-SAME:   outs(%{{.*}} : tensor<?xf32>)
+// CHECK:     %[[EXPANDED_1:.+]] = tensor.expand_shape %[[OUT]]
+// CHECK-SAME:    output_shape [%[[DIM0]], %[[DIM1]]]
+// CHECK:      return %[[EXPANDED_1]]
+
+// -----
+
 #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3)>
 func.func @fuse_reductions(%arg0 : tensor<2x?x5xf32>, %arg1 : tensor<2x5xf32>, %sz0: index) -> tensor<2x5xf32> {
@@ -425,10 +457,11 @@ func.func @fuse_only_one_reassociation(%arg0 : tensor<?x?xf32>, %arg1 : tensor<4
 //  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 //      CHECK: func @fuse_only_one_reassociation
 // CHECK-SAME:     (%[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<4x?x?x8xf32>, %[[SZ0:.+]]: index, %[[SZ1:.+]]: index)
-//  CHECK-DAG:   %[[C8:.*]] = arith.constant 8 : index
 //  CHECK-DAG:   %[[C2:.*]] = arith.constant 2 : index
-//  CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
+//  CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
 //  CHECK-DAG:   %[[EXPAND_ARG0:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2, 3]{{\]}} output_shape [%[[SZ0]], 4, %[[SZ1]], 8]
+//  CHECK-DAG:   %[[DIM:.+]] = tensor.dim %[[EXPAND_ARG0]], %[[C0]] : tensor<?x4x?x8xf32>
+//  CHECK-DAG:   %[[DIM_2:.+]] = tensor.dim %[[EXPAND_ARG0]], %[[C2]] : tensor<?x4x?x8xf32>
 //  CHECK-DAG:   %[[COLLAPSE_ARG0:.+]] = tensor.collapse_shape %[[EXPAND_ARG0]] {{\[}}[0], [1], [2, 3]{{\]}}
 //  CHECK-DAG:   %[[COLLAPSE_ARG1_0:.+]] = tensor.collapse_shape %[[ARG1]] {{\[}}[0], [1], [2, 3]{{\]}}
 //  CHECK-DAG:   %[[COLLAPSE_ARG1_1:.+]] = tensor.collapse_shape %[[ARG1]] {{\[}}[0], [1], [2, 3]{{\]}}
@@ -437,10 +470,7 @@ func.func @fuse_only_one_reassociation(%arg0 : tensor<?x?xf32>, %arg1 : tensor<4
 // CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel"]
 // CHECK-SAME:       ins(%[[COLLAPSE_ARG0]], %[[COLLAPSE_ARG1_0]] :
 // CHECK-SAME:       outs(%[[COLLAPSE_ARG1_1]] :
-//      CHECK:   %[[DIM:.+]] = tensor.dim %[[GENERIC]], %[[C1]] : tensor<4x?x?xf32>
-//      CHECK:   %[[DIM_2:.+]] = tensor.dim %[[GENERIC]], %[[C2]] : tensor<4x?x?xf32>
-//      CHECK:   %[[VAL_1:.+]] = arith.divsi %[[DIM_2]], %[[C8]] : index
-//      CHECK:   %[[EXPANDED_3:.+]] = tensor.expand_shape %[[GENERIC]] {{\[\[}}0], [1], [2, 3]] output_shape [4, %[[DIM]], %[[VAL_1]], 8] : tensor<4x?x?xf32> into tensor<4x?x?x8xf32>
+//      CHECK:   %[[EXPANDED_3:.+]] = tensor.expand_shape %[[GENERIC]] {{\[\[}}0], [1], [2, 3]] output_shape [4, %[[DIM]], %[[DIM_2]], 8] : tensor<4x?x?xf32> into tensor<4x?x?x8xf32>
 //      CHECK:   return %[[EXPANDED_3]]
 
 // -----
@@ -475,15 +505,16 @@ func.func @fold_non_consecutive_dims(%arg0 : tensor<?x?xi32>, %sz0: index, %sz1:
 //  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d1, d0)>
 //      CHECK: func @fold_non_consecutive_dims(
 // CHECK-SAME:     %[[ARG0:.+]]: tensor<?x?xi32>, %[[SZ0:.+]]: index, %[[SZ1:.+]]: index)
-//      CHECK:   %[[C1:.+]] = arith.constant 1 : index
-//      CHECK:   %[[C4:.+]] = arith.constant 4 : index
-//      CHECK:   %[[C8:.+]] = arith.constant 8 : index
-//      CHECK:   %[[C0:.+]] = arith.constant 0 : index
-//      CHECK:   %[[C2:.+]] = arith.constant 2 : index
+//      CHECK-DAG:   %[[C4:.+]] = arith.constant 4 : index
+//      CHECK-DAG:   %[[C8:.+]] = arith.constant 8 : index
+//      CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//      CHECK-DAG:   %[[C2:.+]] = arith.constant 2 : index
 //      CHECK:   %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[SZ0]], 4, %[[SZ1]], 8] : tensor<?x?xi32> into tensor<?x4x?x8xi32>
-//      CHECK:   %[[DIM:.+]] = tensor.dim %[[EXPANDED]], %[[C0]]
-//      CHECK:   %[[DIM_0:.+]] = tensor.dim %[[EXPANDED]], %[[C2]]
+//      CHECK-DAG:   %[[DIM:.+]] = tensor.dim %[[EXPANDED]], %[[C0]]
+//      CHECK-DAG:   %[[DIM_0:.+]] = tensor.dim %[[EXPANDED]], %[[C2]]
 //      CHECK:   %[[INIT:.+]] = tensor.empty(%[[DIM_0]], %[[DIM]])
+//      CHECK-DAG:   %[[DIM_1:.+]] = tensor.dim %[[EXPANDED]], %[[C0]]
+//      CHECK-DAG:   %[[DIM_2:.+]] = tensor.dim %[[EXPANDED]], %[[C2]]
 //      CHECK:   %[[COLLAPSE_INIT:.+]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1], [2, 3]{{\]}}
 //      CHECK:   %[[GENERIC:.+]] = linalg.generic
 // CHECK-SAME:       indexing_maps = [#[[MAP0]], #[[MAP1]]]
@@ -502,11 +533,7 @@ func.func @fold_non_consecutive_dims(%arg0 : tensor<?x?xi32>, %sz0: index, %sz1:
 //  CHECK-DAG:       %[[T6:.+]] = arith.addi %[[T5]], %[[T3]]
 //  CHECK-DAG:       %[[T7:.+]] = arith.index_cast %[[T6]]
 //      CHECK:       linalg.yield %[[T7]]
-//      CHECK:   %[[DIM_1:.+]] = tensor.dim %[[GENERIC]], %[[C0]] : tensor<?x?xi32>
-//      CHECK:   %[[DIM_2:.+]] = tensor.dim %[[GENERIC]], %[[C1]] : tensor<?x?xi32>
-//      CHECK:   %[[VAL_2:.+]] = arith.divsi %[[DIM_1]], %[[C8]] : index
-//      CHECK:   %[[VAL_3:.+]] = arith.divsi %[[DIM_2]], %[[C4]] : index
-//      CHECK:   %[[EXPANDED_3:.+]] = tensor.expand_shape %[[GENERIC]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[VAL_2]], 8, %[[VAL_3]], 4] : tensor<?x?xi32> into tensor<?x8x?x4xi32>
+//      CHECK:   %[[EXPANDED_3:.+]] = tensor.expand_shape %[[GENERIC]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[DIM_2]], 8, %[[DIM_1]], 4] : tensor<?x?xi32> into tensor<?x8x?x4xi32>
 //      CHECK:   return %[[EXPANDED_3]]
 
 // -----
diff --git a/mlir/test/Dialect/Linalg/fusion-push-reshape.mlir b/mlir/test/Dialect/Linalg/fusion-push-reshape.mlir
index 7acbd843cd1e7c9..fd3c32172250862 100644
--- a/mlir/test/Dialect/Linalg/fusion-push-reshape.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-push-reshape.mlir
@@ -5,15 +5,14 @@
 
 // CHECK-LABEL: func @reshape
 // CHECK-SAME: (%[[A:.*]]: tensor<?x16xf32>, %[[B:.*]]: tensor<16xf32>, %[[INIT:.*]]: tensor<?x112x16xf32>, %[[SZ0:.*]]: index)
-//      CHECK: %[[C112:.*]] = arith.constant 112 : index
 //      CHECK: %[[C0:.*]] = arith.constant 0 : index
+//      CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[A]]
+//      CHECK: %[[DIM:.*]] = tensor.dim %[[EXPANDED]], %[[C0]]
 //      CHECK: %[[RI:.*]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1], [2]] : tensor<?x112x16xf32> into tensor<?x16xf32>
 //      CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP3]], #[[$MAP2]]],
 // CHECK-SAME: iterator_types = ["parallel", "parallel"]}
 // CHECK-SAME: ins(%[[A]], %[[B]] : tensor<?x16xf32>, tensor<16xf32>) outs(%[[RI]] : tensor<?x16xf32>)
-//      CHECK: %[[DIM:.*]] = tensor.dim %[[R]], %[[C0]] : tensor<?x16xf32>
-//      CHECK: %[[VAL_1:.*]] = arith.divsi %[[DIM]], %[[C112]] : index
-//      CHECK: %[[RR:.*]] = tensor.expand_shape %[[R]] {{\[\[}}0, 1], [2]] output_shape [%[[VAL_1]], 112, 16] : tensor<?x16xf32> into tensor<?x112x16xf32>
+//      CHECK: %[[RR:.*]] = tensor.expand_shape %[[R]] {{\[\[}}0, 1], [2]] output_shape [%[[DIM]], 112, 16] : tensor<?x16xf32> into tensor<?x112x16xf32>
 //      CHECK: return %[[RR]] : tensor<?x112x16xf32>
 func.func @reshape(%A: tensor<?x16xf32>, %B: tensor<16xf32>, %init: tensor<?x112x16xf32>, %sz0: index) -> tensor<?x112x16xf32> {
   %0 = tensor.expand_shape %A [[0, 1], [2]] output_shape [%sz0, 112, 16]

From 533294d12b84d4bd8d1b30938b936629de1333e4 Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Sun, 9 Feb 2025 00:40:42 +0000
Subject: [PATCH 2/2] Address comments

---
 .../Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp  | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 3342c272f0f1062..1ce29e8206d1f1a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1548,10 +1548,9 @@ static Value getCollapsedOpOperand(Location loc, LinalgOp op,
 
 /// Modify the `linalg.index` operations in the original generic op, to its
 /// value in the collapsed operation.
-void generateCollapsedIndexingRegion(Location loc, Block *block,
-                                     const CollapsingInfo &collapsingInfo,
-                                     ArrayRef<OpFoldResult> loopRange,
-                                     RewriterBase &rewriter) {
+static void generateCollapsedIndexingRegion(
+    Location loc, Block *block, const CollapsingInfo &collapsingInfo,
+    ArrayRef<OpFoldResult> loopRange, RewriterBase &rewriter) {
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPointToStart(block);
 
@@ -1748,6 +1747,9 @@ FailureOr<CollapseResult> mlir::linalg::collapseOpIterationDims(
           op.getIndexingMapMatchingResult(originalResult.value());
       SmallVector<ReassociationIndices> reassociation =
           getOperandReassociation(indexingMap, collapsingInfo);
+      assert(
+          indexingMap.isProjectedPermutation() &&
+          "Expected indexing map to be a projected permutation for collapsing");
       SmallVector<OpFoldResult> resultShape =
           applyPermutationMap(indexingMap, ArrayRef(loopBound));
       Value result;



More information about the Mlir-commits mailing list