[Mlir-commits] [mlir] c9ff839 - [mlir][Linalg] Fix linalg.generic iteration domain collapse for dynamic dims (#118208)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Feb 19 12:02:48 PST 2025


Author: Kunwar Grover
Date: 2025-02-20T01:32:44+05:30
New Revision: c9ff8399647cd15cdb9f8853b45854920de17162

URL: https://github.com/llvm/llvm-project/commit/c9ff8399647cd15cdb9f8853b45854920de17162
DIFF: https://github.com/llvm/llvm-project/commit/c9ff8399647cd15cdb9f8853b45854920de17162.diff

LOG: [mlir][Linalg] Fix linalg.generic iteration domain collapse for dynamic dims (#118208)

This PR fixes how the iteration domain of linalg.generic is collapsed when
fusing with tensor.expand_shape. Previously, the output_shape for
tensor.expand_shape was inferred, which does not work in general and only
holds in some special cases.

This patch makes the logic explicitly set the bounds of the new
collapsed iteration domain, because we already know them.
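
For illustration, here is a lightly adapted sketch of the new test case added
below: a tensor.expand_shape with dynamic output sizes feeding a
linalg.generic whose iteration domain gets collapsed (the function name and
the tensor.empty operand order are simplified relative to the actual test).

  #map = affine_map<(d0, d1) -> (d0, d1)>
  func.func @collapse_dynamic(%arg0 : tensor<?xf32>, %sz0 : index, %sz1 : index) -> tensor<?x?xf32> {
    // The two expanded sizes are only known at runtime.
    %0 = tensor.expand_shape %arg0 [[0, 1]] output_shape [%sz0, %sz1]
        : tensor<?xf32> into tensor<?x?xf32>
    %init = tensor.empty(%sz0, %sz1) : tensor<?x?xf32>
    %1 = linalg.generic {
        indexing_maps = [#map, #map],
        iterator_types = ["parallel", "parallel"]}
        ins(%0 : tensor<?x?xf32>)
        outs(%init : tensor<?x?xf32>) {
      ^bb0(%b0 : f32, %b1 : f32):
        %out = arith.negf %b0 : f32
        linalg.yield %out : f32
    } -> tensor<?x?xf32>
    return %1 : tensor<?x?xf32>
  }

After collapsing, the linalg.generic runs directly on the flat tensor<?xf32>,
and the tensor.expand_shape that restores the 2-D result takes its
output_shape from tensor.dim of the already-expanded input (the known loop
bounds) rather than from an arith.divsi on the collapsed result size.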

---------

Co-authored-by: Jakub Kuderski <kubakuderski at gmail.com>

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
    mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
    mlir/test/Dialect/Linalg/fusion-push-reshape.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 60cae77644291..f4b6955823085 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1548,10 +1548,9 @@ static Value getCollapsedOpOperand(Location loc, LinalgOp op,
 
 /// Modify the `linalg.index` operations in the original generic op, to its
 /// value in the collapsed operation.
-void generateCollapsedIndexingRegion(Location loc, Block *block,
-                                     const CollapsingInfo &collapsingInfo,
-                                     ValueRange loopRange,
-                                     RewriterBase &rewriter) {
+static void generateCollapsedIndexingRegion(
+    Location loc, Block *block, const CollapsingInfo &collapsingInfo,
+    ArrayRef<OpFoldResult> loopRange, RewriterBase &rewriter) {
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPointToStart(block);
 
@@ -1572,10 +1571,12 @@ void generateCollapsedIndexingRegion(Location loc, Block *block,
     Value newIndexVal =
         rewriter.create<linalg::IndexOp>(loc, foldedDims.index());
     for (auto dim : llvm::reverse(foldedDimsRef.drop_front())) {
+      Value loopDim =
+          getValueOrCreateConstantIndexOp(rewriter, loc, loopRange[dim]);
       indexReplacementVals[dim] =
-          rewriter.create<arith::RemSIOp>(loc, newIndexVal, loopRange[dim]);
+          rewriter.createOrFold<arith::RemSIOp>(loc, newIndexVal, loopDim);
       newIndexVal =
-          rewriter.create<arith::DivSIOp>(loc, newIndexVal, loopRange[dim]);
+          rewriter.createOrFold<arith::DivSIOp>(loc, newIndexVal, loopDim);
     }
     indexReplacementVals[foldedDims.value().front()] = newIndexVal;
   }
@@ -1722,14 +1723,13 @@ FailureOr<CollapseResult> mlir::linalg::collapseOpIterationDims(
   LinalgOp collapsedOp = createCollapsedOp(op, collapsingInfo, rewriter);
 
   Location loc = op->getLoc();
+  SmallVector<OpFoldResult> loopBound =
+      llvm::map_to_vector(loopRanges, [](Range range) { return range.size; });
+
   if (collapsedOp.hasIndexSemantics()) {
     // Collect the loop range of the generic op.
     OpBuilder::InsertionGuard g(rewriter);
     rewriter.setInsertionPoint(collapsedOp);
-    SmallVector<Value> loopBound =
-        llvm::map_to_vector(loopRanges, [&](Range range) {
-          return getValueOrCreateConstantIndexOp(rewriter, loc, range.size);
-        });
     generateCollapsedIndexingRegion(loc, &collapsedOp->getRegion(0).front(),
                                     collapsingInfo, loopBound, rewriter);
   }
@@ -1747,15 +1747,22 @@ FailureOr<CollapseResult> mlir::linalg::collapseOpIterationDims(
           op.getIndexingMapMatchingResult(originalResult.value());
       SmallVector<ReassociationIndices> reassociation =
           getOperandReassociation(indexingMap, collapsingInfo);
+      assert(
+          indexingMap.isProjectedPermutation() &&
+          "Expected indexing map to be a projected permutation for collapsing");
+      SmallVector<OpFoldResult> resultShape =
+          applyPermutationMap(indexingMap, ArrayRef(loopBound));
       Value result;
       if (isa<MemRefType>(collapsedOpResult.getType())) {
         MemRefType expandShapeResultType = MemRefType::get(
             originalResultType.getShape(), originalResultType.getElementType());
         result = rewriter.create<memref::ExpandShapeOp>(
-            loc, expandShapeResultType, collapsedOpResult, reassociation);
+            loc, expandShapeResultType, collapsedOpResult, reassociation,
+            resultShape);
       } else {
         result = rewriter.create<tensor::ExpandShapeOp>(
-            loc, originalResultType, collapsedOpResult, reassociation);
+            loc, originalResultType, collapsedOpResult, reassociation,
+            resultShape);
       }
       results.push_back(result);
     } else {

diff --git a/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir b/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
index 7db997cd4c0b5..89734e7542801 100644
--- a/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
+++ b/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
@@ -225,6 +225,38 @@ func.func @fuse_by_collapsing_dynamic(%arg0 : tensor<?x?x?x?x?xi32>,
 
 // -----
 
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+func.func @fuse_by_collapsing_dynamic_2(%arg0 : tensor<?xf32>, %sz0: index, %sz1: index) -> tensor<?x?xf32> {
+  %0 = tensor.expand_shape %arg0 [[0, 1]] output_shape [%sz0, %sz1] : tensor<?xf32> into tensor<?x?xf32>
+  %init = tensor.empty(%sz1, %sz0) : tensor<?x?xf32>
+  %1 = linalg.generic {
+      indexing_maps = [#map0, #map0],
+      iterator_types = ["parallel", "parallel"]}
+      ins(%0 : tensor<?x?xf32>) 
+      outs(%init : tensor<?x?xf32>) {
+        ^bb0(%b0 : f32, %b1 : f32):
+          %out = arith.negf %b0 : f32
+          linalg.yield %out : f32
+      } -> tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: func @fuse_by_collapsing_dynamic_2
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?xf32>
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK:     %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]]
+// CHECK-DAG: %[[DIM0:.+]] = tensor.dim %[[EXPANDED]], %[[C0]]
+// CHECK-DAG: %[[DIM1:.+]] = tensor.dim %[[EXPANDED]], %[[C1]]
+// CHECK:     %[[OUT:.+]] = linalg.generic
+// CHECK-SAME:   ins(%[[ARG0]] : tensor<?xf32>)
+// CHECK-SAME:   outs(%{{.*}} : tensor<?xf32>)
+// CHECK:     %[[EXPANDED_1:.+]] = tensor.expand_shape %[[OUT]]
+// CHECK-SAME:    output_shape [%[[DIM0]], %[[DIM1]]]
+// CHECK:      return %[[EXPANDED_1]]
+
+// -----
+
 #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3)>
 func.func @fuse_reductions(%arg0 : tensor<2x?x5xf32>, %arg1 : tensor<2x5xf32>, %sz0: index) -> tensor<2x5xf32> {
@@ -425,10 +457,11 @@ func.func @fuse_only_one_reassociation(%arg0 : tensor<?x?xf32>, %arg1 : tensor<4
 //  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 //      CHECK: func @fuse_only_one_reassociation
 // CHECK-SAME:     (%[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<4x?x?x8xf32>, %[[SZ0:.+]]: index, %[[SZ1:.+]]: index)
-//  CHECK-DAG:   %[[C8:.*]] = arith.constant 8 : index
 //  CHECK-DAG:   %[[C2:.*]] = arith.constant 2 : index
-//  CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
+//  CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
 //  CHECK-DAG:   %[[EXPAND_ARG0:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2, 3]{{\]}} output_shape [%[[SZ0]], 4, %[[SZ1]], 8]
+//  CHECK-DAG:   %[[DIM:.+]] = tensor.dim %[[EXPAND_ARG0]], %[[C0]] : tensor<?x4x?x8xf32>
+//  CHECK-DAG:   %[[DIM_2:.+]] = tensor.dim %[[EXPAND_ARG0]], %[[C2]] : tensor<?x4x?x8xf32>
 //  CHECK-DAG:   %[[COLLAPSE_ARG0:.+]] = tensor.collapse_shape %[[EXPAND_ARG0]] {{\[}}[0], [1], [2, 3]{{\]}}
 //  CHECK-DAG:   %[[COLLAPSE_ARG1_0:.+]] = tensor.collapse_shape %[[ARG1]] {{\[}}[0], [1], [2, 3]{{\]}}
 //  CHECK-DAG:   %[[COLLAPSE_ARG1_1:.+]] = tensor.collapse_shape %[[ARG1]] {{\[}}[0], [1], [2, 3]{{\]}}
@@ -437,10 +470,7 @@ func.func @fuse_only_one_reassociation(%arg0 : tensor<?x?xf32>, %arg1 : tensor<4
 // CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel"]
 // CHECK-SAME:       ins(%[[COLLAPSE_ARG0]], %[[COLLAPSE_ARG1_0]] :
 // CHECK-SAME:       outs(%[[COLLAPSE_ARG1_1]] :
-//      CHECK:   %[[DIM:.+]] = tensor.dim %[[GENERIC]], %[[C1]] : tensor<4x?x?xf32>
-//      CHECK:   %[[DIM_2:.+]] = tensor.dim %[[GENERIC]], %[[C2]] : tensor<4x?x?xf32>
-//      CHECK:   %[[VAL_1:.+]] = arith.divsi %[[DIM_2]], %[[C8]] : index
-//      CHECK:   %[[EXPANDED_3:.+]] = tensor.expand_shape %[[GENERIC]] {{\[\[}}0], [1], [2, 3]] output_shape [4, %[[DIM]], %[[VAL_1]], 8] : tensor<4x?x?xf32> into tensor<4x?x?x8xf32>
+//      CHECK:   %[[EXPANDED_3:.+]] = tensor.expand_shape %[[GENERIC]] {{\[\[}}0], [1], [2, 3]] output_shape [4, %[[DIM]], %[[DIM_2]], 8] : tensor<4x?x?xf32> into tensor<4x?x?x8xf32>
 //      CHECK:   return %[[EXPANDED_3]]
 
 // -----
@@ -475,15 +505,16 @@ func.func @fold_non_consecutive_dims(%arg0 : tensor<?x?xi32>, %sz0: index, %sz1:
 //  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d1, d0)>
 //      CHECK: func @fold_non_consecutive_dims(
 // CHECK-SAME:     %[[ARG0:.+]]: tensor<?x?xi32>, %[[SZ0:.+]]: index, %[[SZ1:.+]]: index)
-//      CHECK:   %[[C1:.+]] = arith.constant 1 : index
-//      CHECK:   %[[C4:.+]] = arith.constant 4 : index
-//      CHECK:   %[[C8:.+]] = arith.constant 8 : index
-//      CHECK:   %[[C0:.+]] = arith.constant 0 : index
-//      CHECK:   %[[C2:.+]] = arith.constant 2 : index
+//      CHECK-DAG:   %[[C4:.+]] = arith.constant 4 : index
+//      CHECK-DAG:   %[[C8:.+]] = arith.constant 8 : index
+//      CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//      CHECK-DAG:   %[[C2:.+]] = arith.constant 2 : index
 //      CHECK:   %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[SZ0]], 4, %[[SZ1]], 8] : tensor<?x?xi32> into tensor<?x4x?x8xi32>
-//      CHECK:   %[[DIM:.+]] = tensor.dim %[[EXPANDED]], %[[C0]]
-//      CHECK:   %[[DIM_0:.+]] = tensor.dim %[[EXPANDED]], %[[C2]]
+//      CHECK-DAG:   %[[DIM:.+]] = tensor.dim %[[EXPANDED]], %[[C0]]
+//      CHECK-DAG:   %[[DIM_0:.+]] = tensor.dim %[[EXPANDED]], %[[C2]]
 //      CHECK:   %[[INIT:.+]] = tensor.empty(%[[DIM_0]], %[[DIM]])
+//      CHECK-DAG:   %[[DIM_1:.+]] = tensor.dim %[[EXPANDED]], %[[C0]]
+//      CHECK-DAG:   %[[DIM_2:.+]] = tensor.dim %[[EXPANDED]], %[[C2]]
 //      CHECK:   %[[COLLAPSE_INIT:.+]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1], [2, 3]{{\]}}
 //      CHECK:   %[[GENERIC:.+]] = linalg.generic
 // CHECK-SAME:       indexing_maps = [#[[MAP0]], #[[MAP1]]]
@@ -502,11 +533,7 @@ func.func @fold_non_consecutive_dims(%arg0 : tensor<?x?xi32>, %sz0: index, %sz1:
 //  CHECK-DAG:       %[[T6:.+]] = arith.addi %[[T5]], %[[T3]]
 //  CHECK-DAG:       %[[T7:.+]] = arith.index_cast %[[T6]]
 //      CHECK:       linalg.yield %[[T7]]
-//      CHECK:   %[[DIM_1:.+]] = tensor.dim %[[GENERIC]], %[[C0]] : tensor<?x?xi32>
-//      CHECK:   %[[DIM_2:.+]] = tensor.dim %[[GENERIC]], %[[C1]] : tensor<?x?xi32>
-//      CHECK:   %[[VAL_2:.+]] = arith.divsi %[[DIM_1]], %[[C8]] : index
-//      CHECK:   %[[VAL_3:.+]] = arith.divsi %[[DIM_2]], %[[C4]] : index
-//      CHECK:   %[[EXPANDED_3:.+]] = tensor.expand_shape %[[GENERIC]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[VAL_2]], 8, %[[VAL_3]], 4] : tensor<?x?xi32> into tensor<?x8x?x4xi32>
+//      CHECK:   %[[EXPANDED_3:.+]] = tensor.expand_shape %[[GENERIC]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[DIM_2]], 8, %[[DIM_1]], 4] : tensor<?x?xi32> into tensor<?x8x?x4xi32>
 //      CHECK:   return %[[EXPANDED_3]]
 
 // -----

diff --git a/mlir/test/Dialect/Linalg/fusion-push-reshape.mlir b/mlir/test/Dialect/Linalg/fusion-push-reshape.mlir
index 7acbd843cd1e7..fd3c321722508 100644
--- a/mlir/test/Dialect/Linalg/fusion-push-reshape.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-push-reshape.mlir
@@ -5,15 +5,14 @@
 
 // CHECK-LABEL: func @reshape
 // CHECK-SAME: (%[[A:.*]]: tensor<?x16xf32>, %[[B:.*]]: tensor<16xf32>, %[[INIT:.*]]: tensor<?x112x16xf32>, %[[SZ0:.*]]: index)
-//      CHECK: %[[C112:.*]] = arith.constant 112 : index
 //      CHECK: %[[C0:.*]] = arith.constant 0 : index
+//      CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[A]]
+//      CHECK: %[[DIM:.*]] = tensor.dim %[[EXPANDED]], %[[C0]]
 //      CHECK: %[[RI:.*]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1], [2]] : tensor<?x112x16xf32> into tensor<?x16xf32>
 //      CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP3]], #[[$MAP2]]],
 // CHECK-SAME: iterator_types = ["parallel", "parallel"]}
 // CHECK-SAME: ins(%[[A]], %[[B]] : tensor<?x16xf32>, tensor<16xf32>) outs(%[[RI]] : tensor<?x16xf32>)
-//      CHECK: %[[DIM:.*]] = tensor.dim %[[R]], %[[C0]] : tensor<?x16xf32>
-//      CHECK: %[[VAL_1:.*]] = arith.divsi %[[DIM]], %[[C112]] : index
-//      CHECK: %[[RR:.*]] = tensor.expand_shape %[[R]] {{\[\[}}0, 1], [2]] output_shape [%[[VAL_1]], 112, 16] : tensor<?x16xf32> into tensor<?x112x16xf32>
+//      CHECK: %[[RR:.*]] = tensor.expand_shape %[[R]] {{\[\[}}0, 1], [2]] output_shape [%[[DIM]], 112, 16] : tensor<?x16xf32> into tensor<?x112x16xf32>
 //      CHECK: return %[[RR]] : tensor<?x112x16xf32>
 func.func @reshape(%A: tensor<?x16xf32>, %B: tensor<16xf32>, %init: tensor<?x112x16xf32>, %sz0: index) -> tensor<?x112x16xf32> {
   %0 = tensor.expand_shape %A [[0, 1], [2]] output_shape [%sz0, 112, 16]

More information about the Mlir-commits mailing list