[Mlir-commits] [mlir] [MLIR] Add reshape propagation through tensor.pad (PR #136681)

Hyunsung Lee llvmlistbot at llvm.org
Tue Apr 22 03:28:30 PDT 2025


https://github.com/ita9naiwa updated https://github.com/llvm/llvm-project/pull/136681

>From 1b4634142b861a60dfae0c4e16e2caffb1d150a0 Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <ita9naiwa at gmail.com>
Date: Tue, 22 Apr 2025 19:04:02 +0900
Subject: [PATCH 1/3] Add FoldReshapeWithProducerPadOpByExpansion

---
 .../Linalg/Transforms/ElementwiseOpFusion.cpp | 124 ++++++++++++++++++-
 1 file changed, 123 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index bf70597d5ddfe..4e8af2bf46014 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -9,8 +9,8 @@
 // This file implements the linalg dialect Fusion on tensors operations pass.
 //
 //===----------------------------------------------------------------------===//
-
 #include "mlir/Dialect/Linalg/Passes.h"
+#include <iostream>
 
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
@@ -21,6 +21,7 @@
 #include "mlir/Dialect/Tensor/Transforms/Transforms.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Support/LLVM.h"
@@ -1101,6 +1102,125 @@ class FoldPadWithProducerReshapeOpByExpansion
   ControlFusionFn controlFoldingReshapes;
 };
 
+/// Pattern to move a tensor.expand_shape op above its producer tensor.pad op:
+/// the pad's source is expanded first and the expanded tensor is then padded.
+///
+/// ```
+/// BEFORE:
+/// %padded = tensor.pad %src low[0, 1, 1] high[0, 1, 1]
+///     : tensor<512x256x256xf32> to tensor<512x258x258xf32>
+/// %expanded = tensor.expand_shape %padded [[0, 1], [2], [3]]
+///     : tensor<512x258x258xf32> into tensor<32x16x258x258xf32>
+///
+/// AFTER:
+/// %expanded = tensor.expand_shape %src [[0, 1], [2], [3]]
+///     : tensor<512x256x256xf32> into tensor<32x16x256x256xf32>
+/// %padded = tensor.pad %expanded low[0, 0, 1, 1] high[0, 0, 1, 1]
+///     : tensor<32x16x256x256xf32> to tensor<32x16x258x258xf32>
+/// ```
+///
+/// Every source dimension that is split into several result dimensions must
+/// have static zero padding. Dimensions that are not split keep their padding,
+/// which may be dynamic. The pad body must yield a value that does not depend
+/// on the iteration indices, because the new pad op has a different rank.
+struct FoldReshapeWithProducerPadOpByExpansion
+    : public OpRewritePattern<tensor::ExpandShapeOp> {
+
+  FoldReshapeWithProducerPadOpByExpansion(MLIRContext *context,
+                                          ControlFusionFn foldReshapes,
+                                          PatternBenefit benefit = 1)
+      : OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
+        controlFoldingReshapes(std::move(foldReshapes)) {}
+
+  LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp,
+                                PatternRewriter &rewriter) const override {
+    auto padOp = expandOp.getSrc().getDefiningOp<tensor::PadOp>();
+    if (!padOp)
+      return failure();
+
+    // Other users would keep the original pad alive, duplicating the padding.
+    if (!padOp->hasOneUse())
+      return failure();
+
+    if (!controlFoldingReshapes(&expandOp.getSrcMutable())) {
+      return rewriter.notifyMatchFailure(expandOp,
+                                         "fusion blocked by control function");
+    }
+
+    // The pad body cannot be reused at a different rank, so only a padding
+    // value that does not depend on the indices can be carried over. Without
+    // this check the PadOp builder below would be handed a null value.
+    Value paddingValue = padOp.getConstantPaddingValue();
+    if (!paddingValue) {
+      return rewriter.notifyMatchFailure(
+          expandOp, "pad op does not yield a constant padding value");
+    }
+
+    SmallVector<ReassociationIndices> reassociations =
+        expandOp.getReassociationIndices();
+
+    // Padding of a dimension that is split cannot be expressed as padding of
+    // the expanded tensor. Dynamic amounts read as ShapedType::kDynamic here,
+    // so they are rejected for split dimensions as well.
+    for (auto [reInd, l, h] : llvm::zip_equal(
+             reassociations, padOp.getStaticLow(), padOp.getStaticHigh())) {
+      if (reInd.size() != 1 && (l != 0 || h != 0)) {
+        return rewriter.notifyMatchFailure(
+            expandOp, "expanded dimension has non-zero padding");
+      }
+    }
+
+    // All checks passed; no IR may be created before this point.
+    Location loc = expandOp.getLoc();
+    RankedTensorType srcType = padOp.getSourceType();
+    SmallVector<OpFoldResult> mixedLow = padOp.getMixedLowPad();
+    SmallVector<OpFoldResult> mixedHigh = padOp.getMixedHighPad();
+    SmallVector<OpFoldResult> resultShape = expandOp.getMixedOutputShape();
+    ArrayRef<int64_t> resultStaticShape = expandOp.getResultType().getShape();
+
+    SmallVector<OpFoldResult> newLow, newHigh, newSrcShape;
+    SmallVector<int64_t> newSrcStaticShape;
+    for (auto [inDimIdx, reInd] : llvm::enumerate(reassociations)) {
+      // Split dimensions have zero padding, so replicating the entry is valid.
+      newLow.append(reInd.size(), mixedLow[inDimIdx]);
+      newHigh.append(reInd.size(), mixedHigh[inDimIdx]);
+      if (reInd.size() == 1) {
+        // Unsplit dimension: the expanded source keeps the unpadded extent.
+        newSrcShape.push_back(
+            tensor::getMixedSize(rewriter, loc, padOp.getSource(), inDimIdx));
+        newSrcStaticShape.push_back(srcType.getDimSize(inDimIdx));
+        continue;
+      }
+      // Split dimension: with zero padding the pad source already has the
+      // extents of the expand_shape result. Taking them from the mixed output
+      // shape keeps dynamic extents as SSA values instead of doing arithmetic
+      // on ShapedType::kDynamic.
+      for (int64_t outDimIdx : reInd) {
+        newSrcShape.push_back(resultShape[outDimIdx]);
+        newSrcStaticShape.push_back(resultStaticShape[outDimIdx]);
+      }
+    }
+
+    // A constant materialized inside the old pad body does not dominate the
+    // new pad body, so re-create it at the current insertion point.
+    if (paddingValue.getParentRegion() == &padOp.getRegion()) {
+      Operation *clonedCst = rewriter.clone(*paddingValue.getDefiningOp());
+      paddingValue = clonedCst->getResult(0);
+    }
+
+    auto newExpandOp = rewriter.create<tensor::ExpandShapeOp>(
+        loc, RankedTensorType::get(newSrcStaticShape, srcType.getElementType()),
+        padOp.getSource(), reassociations, newSrcShape);
+    rewriter.replaceOpWithNewOp<tensor::PadOp>(
+        expandOp, expandOp.getResultType(), newExpandOp.getResult(), newLow,
+        newHigh, paddingValue, padOp.getNofold());
+    return success();
+  }
+
+private:
+  ControlFusionFn controlFoldingReshapes;
+};
+
 /// Pattern to fold a tensor.expand_shape op with its producer generic op
 /// by expanding the dimensionality of the loop in the producer op.
 struct FoldReshapeWithGenericOpByExpansion
@@ -2249,6 +2369,8 @@ void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
                                                     controlFoldingReshapes);
   patterns.add<FoldPadWithProducerReshapeOpByExpansion>(patterns.getContext(),
                                                         controlFoldingReshapes);
+  patterns.add<FoldReshapeWithProducerPadOpByExpansion>(patterns.getContext(),
+                                                        controlFoldingReshapes);
   patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
                                                      controlFoldingReshapes);
 }

>From 333c4e0e847082937eed602c6503b9d1e27fe844 Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <ita9naiwa at gmail.com>
Date: Tue, 22 Apr 2025 19:04:14 +0900
Subject: [PATCH 2/3] Add test

---
 mlir/test/Dialect/Linalg/reshape_fusion.mlir | 27 ++++++++++++++++++--
 1 file changed, 25 insertions(+), 2 deletions(-)

diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
index 67b4f2b32bad5..d6eafce6d049f 100644
--- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -247,7 +247,7 @@ func.func @indexed_consumer_reshape_producer_fusion(%arg0 : tensor<?x?x4x?xi32>,
 
 #map0 = affine_map<(d0, d1) -> (d0, d1)>
 func.func @indexed_producer_reshape_consumer_fusion(%arg0 : tensor<?x?xi32>,
-                                         %arg1 : tensor<?x?xi32>, 
+                                         %arg1 : tensor<?x?xi32>,
                                          %sz0: index, %sz1: index) ->
                                          tensor<?x?x4x5xi32>
 {
@@ -515,7 +515,7 @@ func.func @fuse_dynamic_dims(%arg0: tensor<?x?xf32>) -> tensor<?xf32> {
 // -----
 
 func.func @reshape_as_consumer_permutation_with_multiple_results
-  (%a : tensor<?x?x?xf32>, %b : tensor<?x?xf32>, %sz0: index, 
+  (%a : tensor<?x?x?xf32>, %b : tensor<?x?xf32>, %sz0: index,
    %sz1: index, %sz2: index, %sz3: index, %sz4: index)
     -> (tensor<?x2x?x3x4x?xf32>, tensor<?x?x2x3x4x?xf32>) {
   %c:2 = linalg.generic {
@@ -893,3 +893,45 @@ func.func @move_operand_deps(%arg0 : tensor<?x128xf16>,
 //      CHECK:   %[[GENERIC:.+]] = linalg.generic
 // CHECK-SAME:       ins(%[[EXPANDED]] :
 //      CHECK:   return %[[GENERIC]]
+
+// -----
+
+func.func @fold_tensor_pad_with_expand(%arg0: tensor<512x256x256xf32>) -> tensor<32x16x258x258xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %0   = linalg.fill ins(%cst : f32) outs(%arg0 : tensor<512x256x256xf32>) -> tensor<512x256x256xf32>
+  %padded = tensor.pad %0 low[0, 1, 1] high[0, 1, 1] {
+    ^bb0(%i: index, %j: index, %k: index):
+      tensor.yield %cst : f32
+  } : tensor<512x256x256xf32> to tensor<512x258x258xf32>
+  %expanded = tensor.expand_shape %padded [[0, 1], [2], [3]] output_shape [32, 16, 258, 258] : tensor<512x258x258xf32> into tensor<32x16x258x258xf32>
+  return %expanded : tensor<32x16x258x258xf32>
+}
+//      CHECK: func @fold_tensor_pad_with_expand(
+// CHECK-SAME:     %[[ARG0:[^:]+]]: tensor<512x256x256xf32>
+//  CHECK-DAG:   %[[CST:.+]] = arith.constant 0.000000e+00 : f32
+//  CHECK-DAG:   %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]]
+//      CHECK:   %[[FILLED:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[EXPANDED]] : tensor<32x16x256x256xf32>)
+//      CHECK:   %[[PADDED:.*]] = tensor.pad %[[FILLED]] low[0, 0, 1, 1] high[0, 0, 1, 1]
+//      CHECK:   ^bb0(%[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index):
+//      CHECK:     tensor.yield %[[CST]] : f32
+//      CHECK:   } : tensor<32x16x256x256xf32> to tensor<32x16x258x258xf32>
+//      CHECK:   return %[[PADDED]] : tensor<32x16x258x258xf32>
+
+// -----
+
+// Padding on a source dimension that is split by the expand_shape cannot be
+// expressed on the expanded tensor, so the IR must be left unchanged.
+func.func @no_fold_tensor_pad_with_expand_on_padded_dim(%arg0: tensor<510x256xf32>) -> tensor<32x16x256xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %padded = tensor.pad %arg0 low[1, 0] high[1, 0] {
+    ^bb0(%i: index, %j: index):
+      tensor.yield %cst : f32
+  } : tensor<510x256xf32> to tensor<512x256xf32>
+  %expanded = tensor.expand_shape %padded [[0, 1], [2]] output_shape [32, 16, 256] : tensor<512x256xf32> into tensor<32x16x256xf32>
+  return %expanded : tensor<32x16x256xf32>
+}
+//      CHECK: func @no_fold_tensor_pad_with_expand_on_padded_dim(
+// CHECK-SAME:     %[[ARG0:[^:]+]]: tensor<510x256xf32>
+//      CHECK:   %[[PADDED:.+]] = tensor.pad %[[ARG0]] low[1, 0] high[1, 0]
+//      CHECK:   %[[EXPANDED:.+]] = tensor.expand_shape %[[PADDED]]
+//      CHECK:   return %[[EXPANDED]]

>From c9ef18aba9d528ea6e20ea65654fa8c05f4815aa Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <ita9naiwa at gmail.com>
Date: Tue, 22 Apr 2025 19:28:17 +0900
Subject: [PATCH 3/3] nit

---
 mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 4e8af2bf46014..ce00ae4883e88 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -9,8 +9,8 @@
 // This file implements the linalg dialect Fusion on tensors operations pass.
 //
 //===----------------------------------------------------------------------===//
+
 #include "mlir/Dialect/Linalg/Passes.h"
-#include <iostream>
 
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
@@ -21,7 +21,6 @@
 #include "mlir/Dialect/Tensor/Transforms/Transforms.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
-#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Support/LLVM.h"



More information about the Mlir-commits mailing list