[Mlir-commits] [mlir] [Draft][MLIR] Add reshape propagation through tensor.pad (PR #136681)
Hyunsung Lee
llvmlistbot at llvm.org
Tue Apr 22 03:35:55 PDT 2025
https://github.com/ita9naiwa updated https://github.com/llvm/llvm-project/pull/136681
From 1b4634142b861a60dfae0c4e16e2caffb1d150a0 Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <ita9naiwa at gmail.com>
Date: Tue, 22 Apr 2025 19:04:02 +0900
Subject: [PATCH 1/4] Add FoldReshapeWithProducerPadOpByExpansion
---
.../Linalg/Transforms/ElementwiseOpFusion.cpp | 83 ++++++++++++++++++-
1 file changed, 82 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index bf70597d5ddfe..4e8af2bf46014 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -9,8 +9,8 @@
// This file implements the linalg dialect Fusion on tensors operations pass.
//
//===----------------------------------------------------------------------===//
-
#include "mlir/Dialect/Linalg/Passes.h"
+#include <iostream>
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
@@ -21,6 +21,7 @@
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
@@ -1101,6 +1102,84 @@ class FoldPadWithProducerReshapeOpByExpansion
ControlFusionFn controlFoldingReshapes;
};
+/// Pattern to fold a tensor.expand_shape op with its producer tensor.pad op
+/// by bubbling the expand_shape before the pad.
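+///
+/// For example (shapes taken from the accompanying test):
+///
+///   %padded = tensor.pad %src low[0, 1, 1] high[0, 1, 1] { ... }
+///       : tensor<512x256x256xf32> to tensor<512x258x258xf32>
+///   %expanded = tensor.expand_shape %padded [[0, 1], [2], [3]]
+///       output_shape [32, 16, 258, 258]
+///       : tensor<512x258x258xf32> into tensor<32x16x258x258xf32>
+///
+/// is rewritten into
+///
+///   %expanded = tensor.expand_shape %src [[0, 1], [2], [3]]
+///       output_shape [32, 16, 256, 256]
+///       : tensor<512x256x256xf32> into tensor<32x16x256x256xf32>
+///   %padded = tensor.pad %expanded low[0, 0, 1, 1] high[0, 0, 1, 1] { ... }
+///       : tensor<32x16x256x256xf32> to tensor<32x16x258x258xf32>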
+struct FoldReshapeWithProducerPadOpByExpansion
+ : public OpRewritePattern<tensor::ExpandShapeOp> {
+
+ FoldReshapeWithProducerPadOpByExpansion(MLIRContext *context,
+ ControlFusionFn foldReshapes,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
+ controlFoldingReshapes(std::move(foldReshapes)) {}
+
+ LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp,
+ PatternRewriter &rewriter) const override {
+ tensor::PadOp padOp = expandOp.getSrc().getDefiningOp<tensor::PadOp>();
+ if (!padOp)
+ return failure();
+
+ if (!padOp->hasOneUse())
+ return failure();
+
+ if (!controlFoldingReshapes(&expandOp.getSrcMutable())) {
+ return rewriter.notifyMatchFailure(expandOp,
+ "fusion blocked by control function");
+ }
+
+    // Bail out if padOp has *any* dynamic padding; the expanded pad amounts
+    // below are computed from the static values only.
+    if (!padOp.getLow().empty() || !padOp.getHigh().empty())
+      return rewriter.notifyMatchFailure(expandOp,
+                                         "dynamic padding is not supported");
+
+    // The tensor::PadOp builder used below requires a constant padding value;
+    // getConstantPaddingValue() returns a null Value otherwise.
+    Value paddingVal = padOp.getConstantPaddingValue();
+    if (!paddingVal)
+      return rewriter.notifyMatchFailure(
+          expandOp, "non-constant padding value is not supported");
+
+ SmallVector<ReassociationIndices> reassociations =
+ expandOp.getReassociationIndices();
+ ArrayRef<int64_t> low = padOp.getStaticLow();
+ ArrayRef<int64_t> high = padOp.getStaticHigh();
+
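+    // Non-zero padding on a dim that is expanded into multiple dims would be
+    // ambiguous, so require such reassociation groups to be unpadded.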
+ for (auto [reInd, l, h] : llvm::zip_equal(reassociations, low, high)) {
+ if (reInd.size() != 1 && (l != 0 || h != 0))
+ return failure();
+ }
+
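+    // Expand the pad amounts by replicating each group's pad across its
+    // expanded dims; multi-dim groups were just checked to be unpadded, so
+    // this only duplicates zeros for them.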
+ SmallVector<OpFoldResult> newLow, newHigh;
+ for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
+ for (size_t i = 0; i < reInd.size(); ++i) {
+ newLow.push_back(padOp.getMixedLowPad()[idx]);
+ newHigh.push_back(padOp.getMixedHighPad()[idx]);
+ }
+ }
+
+    // The expanded source shape is reconstructed from static sizes below, so
+    // restrict the pattern to fully static result types for now.
+    auto reshapeType = cast<RankedTensorType>(expandOp.getType());
+    if (!reshapeType.hasStaticShape())
+      return rewriter.notifyMatchFailure(
+          expandOp, "dynamic result shape is not supported");
+
+    // Compute the expanded source shape by subtracting each group's static
+    // low/high pad from the corresponding dims of the result shape.
+    ArrayRef<int64_t> finalShape = reshapeType.getShape();
+ SmallVector<int64_t> expandedShape;
+ for (auto [inDimIdx, outGroup] : llvm::enumerate(reassociations)) {
+ for (auto outDimIdx : outGroup) {
+ int64_t sz = finalShape[outDimIdx] - low[inDimIdx] - high[inDimIdx];
+ expandedShape.push_back(sz);
+ }
+ }
+
+    // Expand the pad's source first; the new pad then operates on the
+    // expanded tensor.
+ Location loc = expandOp.getLoc();
+ Value expandedSrc = rewriter.create<tensor::ExpandShapeOp>(
+ loc,
+ RankedTensorType::get(expandedShape,
+ padOp.getSourceType().getElementType()),
+ padOp.getSource(), reassociations);
+
+    auto newPadOp = rewriter.create<tensor::PadOp>(
+        loc, expandOp.getType(), expandedSrc, newLow, newHigh, paddingVal,
+        padOp.getNofold());
+ rewriter.replaceOp(expandOp, newPadOp.getResult());
+ return success();
+ }
+
+private:
+ ControlFusionFn controlFoldingReshapes;
+};
+
/// Pattern to fold a tensor.expand_shape op with its producer generic op
/// by expanding the dimensionality of the loop in the producer op.
struct FoldReshapeWithGenericOpByExpansion
@@ -2249,6 +2328,8 @@ void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
controlFoldingReshapes);
patterns.add<FoldPadWithProducerReshapeOpByExpansion>(patterns.getContext(),
controlFoldingReshapes);
+ patterns.add<FoldReshapeWithProducerPadOpByExpansion>(patterns.getContext(),
+ controlFoldingReshapes);
patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
controlFoldingReshapes);
}
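
A minimal sketch of exercising the new pattern from a pass, assuming the
greedy driver's applyPatternsGreedily entry point; only
populateFoldReshapeOpsByExpansionPatterns is from this patch, while the
surrounding boilerplate and the always-true control function are illustrative:

  RewritePatternSet patterns(funcOp.getContext());
  // Allow every reshape fold; a real pass would filter via the control fn.
  linalg::populateFoldReshapeOpsByExpansionPatterns(
      patterns, [](OpOperand *fusedOperand) { return true; });
  if (failed(applyPatternsGreedily(funcOp, std::move(patterns))))
    signalPassFailure();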
From 333c4e0e847082937eed602c6503b9d1e27fe844 Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <ita9naiwa at gmail.com>
Date: Tue, 22 Apr 2025 19:04:14 +0900
Subject: [PATCH 2/4] Add test
---
mlir/test/Dialect/Linalg/reshape_fusion.mlir | 27 ++++++++++++++++++--
1 file changed, 25 insertions(+), 2 deletions(-)
diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
index 67b4f2b32bad5..d6eafce6d049f 100644
--- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -247,7 +247,7 @@ func.func @indexed_consumer_reshape_producer_fusion(%arg0 : tensor<?x?x4x?xi32>,
#map0 = affine_map<(d0, d1) -> (d0, d1)>
func.func @indexed_producer_reshape_consumer_fusion(%arg0 : tensor<?x?xi32>,
- %arg1 : tensor<?x?xi32>,
+ %arg1 : tensor<?x?xi32>,
%sz0: index, %sz1: index) ->
tensor<?x?x4x5xi32>
{
@@ -515,7 +515,7 @@ func.func @fuse_dynamic_dims(%arg0: tensor<?x?xf32>) -> tensor<?xf32> {
// -----
func.func @reshape_as_consumer_permutation_with_multiple_results
- (%a : tensor<?x?x?xf32>, %b : tensor<?x?xf32>, %sz0: index,
+ (%a : tensor<?x?x?xf32>, %b : tensor<?x?xf32>, %sz0: index,
%sz1: index, %sz2: index, %sz3: index, %sz4: index)
-> (tensor<?x2x?x3x4x?xf32>, tensor<?x?x2x3x4x?xf32>) {
%c:2 = linalg.generic {
@@ -893,3 +893,26 @@ func.func @move_operand_deps(%arg0 : tensor<?x128xf16>,
// CHECK: %[[GENERIC:.+]] = linalg.generic
// CHECK-SAME: ins(%[[EXPANDED]] :
// CHECK: return %[[GENERIC]]
+
+// -----
+
+func.func @fold_tensor_pad_with_expand(%arg0: tensor<512x256x256xf32>) -> tensor<32x16x258x258xf32> {
+ %cst = arith.constant 0.000000e+00 : f32
+ %0 = linalg.fill ins(%cst : f32) outs(%arg0 : tensor<512x256x256xf32>) -> tensor<512x256x256xf32>
+ %padded = tensor.pad %0 low[0, 1, 1] high[0, 1, 1] {
+ ^bb0(%i: index, %j: index, %k: index):
+ tensor.yield %cst : f32
+ } : tensor<512x256x256xf32> to tensor<512x258x258xf32>
+ %expanded = tensor.expand_shape %padded [[0, 1], [2], [3]] output_shape [32, 16, 258, 258] : tensor<512x258x258xf32> into tensor<32x16x258x258xf32>
+ return %expanded : tensor<32x16x258x258xf32>
+}
+// CHECK: func @fold_tensor_pad_with_expand(
+// CHECK-SAME: %[[ARG0:[^:]+]]: tensor<512x256x256xf32>
+// CHECK-DAG: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]]
+// CHECK: %[[FILLED:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[EXPANDED]] : tensor<32x16x256x256xf32>)
+// CHECK: %[[PADDED:.*]] = tensor.pad %[[FILLED]] low[0, 0, 1, 1] high[0, 0, 1, 1]
+// CHECK: ^bb0(%[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index):
+// CHECK: tensor.yield %[[CST]] : f32
+// CHECK: } : tensor<32x16x256x256xf32> to tensor<32x16x258x258xf32>
+// CHECK: return %[[PADDED]] : tensor<32x16x258x258xf32>
\ No newline at end of file
From 321aea7396f017257407f0ceb9b3de41cebb5d30 Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <ita9naiwa at gmail.com>
Date: Tue, 22 Apr 2025 19:28:17 +0900
Subject: [PATCH 3/4] nit
---
mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 4e8af2bf46014..ce00ae4883e88 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -9,8 +9,9 @@
// This file implements the linalg dialect Fusion on tensors operations pass.
//
//===----------------------------------------------------------------------===//
+
#include "mlir/Dialect/Linalg/Passes.h"
-#include <iostream>
+
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
@@ -21,7 +22,6 @@
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
-#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
From 4ab75e14a4c7cba8609d0445e1a445d671f2041e Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <ita9naiwa at gmail.com>
Date: Tue, 22 Apr 2025 19:35:44 +0900
Subject: [PATCH 4/4] nit
---
mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp | 1 -
1 file changed, 1 deletion(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index ce00ae4883e88..b87792c1aa5f3 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -12,7 +12,6 @@
#include "mlir/Dialect/Linalg/Passes.h"
-
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"