[Mlir-commits] [mlir] c886d66 - [mlir] Add reshape propagation patterns for tensor.pad (#94489)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Jun 7 06:09:10 PDT 2024
Author: Max191
Date: 2024-06-07T09:09:06-04:00
New Revision: c886d66da03bdf26e6fca68a1b730ae6eb923194
URL: https://github.com/llvm/llvm-project/commit/c886d66da03bdf26e6fca68a1b730ae6eb923194
DIFF: https://github.com/llvm/llvm-project/commit/c886d66da03bdf26e6fca68a1b730ae6eb923194.diff
LOG: [mlir] Add reshape propagation patterns for tensor.pad (#94489)
This PR adds fusion by collapsing and fusion by expansion patterns for
`tensor.pad` ops in ElementwiseOpFusion. Pad ops can be expanded or
collapsed as long as none of the padded dimensions will be expanded or
collapsed.
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
mlir/test/Dialect/Linalg/reshape_fusion.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index ad313c2d5ce60..e73df61c96434 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -956,6 +956,69 @@ class FoldWithProducerReshapeOpByExpansion
ControlFusionFn controlFoldingReshapes;
};
+/// Pattern to move a producer `tensor.collapse_shape` below its consumer
+/// `tensor.pad`, i.e. rewrite
+///
+///   pad(collapse_shape(%src))  ->  collapse_shape(pad(%src))
+///
+/// The rewrite only applies when
+///   - the collapse_shape result has a single use,
+///   - the control function allows fusing along the pad's source operand,
+///   - the pad has a padding value that can be reused outside its body, and
+///   - every reassociation group that collapses more than one dimension has a
+///     static low and high padding of exactly 0.
+class FoldPadWithProducerReshapeOpByExpansion
+    : public OpRewritePattern<tensor::PadOp> {
+public:
+  FoldPadWithProducerReshapeOpByExpansion(MLIRContext *context,
+                                          ControlFusionFn foldReshapes,
+                                          PatternBenefit benefit = 1)
+      : OpRewritePattern<tensor::PadOp>(context, benefit),
+        controlFoldingReshapes(std::move(foldReshapes)) {}
+
+  LogicalResult matchAndRewrite(tensor::PadOp padOp,
+                                PatternRewriter &rewriter) const override {
+    tensor::CollapseShapeOp reshapeOp =
+        padOp.getSource().getDefiningOp<tensor::CollapseShapeOp>();
+    if (!reshapeOp)
+      return failure();
+    if (!reshapeOp->hasOneUse())
+      return failure();
+
+    if (!controlFoldingReshapes(&padOp.getSourceMutable())) {
+      return rewriter.notifyMatchFailure(padOp,
+                                         "fusion blocked by control function");
+    }
+
+    // The new pad op is built from a single padding value. Bail out before
+    // creating any IR if the original pad does not provide one; passing a
+    // null value to the builder would produce an invalid op.
+    Value paddingValue = padOp.getConstantPaddingValue();
+    if (!paddingValue) {
+      return rewriter.notifyMatchFailure(
+          padOp, "cannot fold with non-constant padding value");
+    }
+
+    ArrayRef<int64_t> low = padOp.getStaticLow();
+    ArrayRef<int64_t> high = padOp.getStaticHigh();
+    SmallVector<ReassociationIndices> reassociations =
+        reshapeOp.getReassociationIndices();
+
+    // A collapsed dimension that stems from several source dimensions must not
+    // be padded: only a static padding of 0 is accepted for such groups.
+    for (auto [reInd, l, h] : llvm::zip_equal(reassociations, low, high)) {
+      if (reInd.size() != 1 && (l != 0 || h != 0)) {
+        return rewriter.notifyMatchFailure(
+            padOp, "padding on a dimension produced by collapsing");
+      }
+    }
+
+    // Query the mixed (static or dynamic) pad amounts once instead of once per
+    // expanded dimension.
+    SmallVector<OpFoldResult> mixedLow = padOp.getMixedLowPad();
+    SmallVector<OpFoldResult> mixedHigh = padOp.getMixedHighPad();
+
+    SmallVector<OpFoldResult> newLow, newHigh;
+    RankedTensorType expandedType = reshapeOp.getSrcType();
+    RankedTensorType paddedType = padOp.getResultType();
+    // Start from the un-collapsed source shape and overwrite the sizes of the
+    // dimensions that are actually padded (the singleton groups).
+    SmallVector<int64_t> expandedPaddedShape(expandedType.getShape());
+    for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
+      if (reInd.size() == 1)
+        expandedPaddedShape[reInd[0]] = paddedType.getShape()[idx];
+      // Every source dimension of the group receives the pad amounts of the
+      // collapsed dimension it maps to.
+      newLow.append(reInd.size(), mixedLow[idx]);
+      newHigh.append(reInd.size(), mixedHigh[idx]);
+    }
+
+    Location loc = padOp->getLoc();
+    RankedTensorType expandedPaddedType = paddedType.clone(expandedPaddedShape);
+    auto newPadOp = rewriter.create<tensor::PadOp>(
+        loc, expandedPaddedType, reshapeOp.getSrc(), newLow, newHigh,
+        paddingValue, padOp.getNofold());
+
+    rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
+        padOp, padOp.getResultType(), newPadOp.getResult(), reassociations);
+
+    return success();
+  }
+
+private:
+  /// Callback deciding whether fusion along a given operand is allowed.
+  ControlFusionFn controlFoldingReshapes;
+};
+
/// Pattern to fold a tensor.expand_shape op with its producer generic op
/// by expanding the dimensionality of the loop in the producer op.
struct FoldReshapeWithGenericOpByExpansion
@@ -1702,6 +1765,85 @@ class FoldWithProducerReshapeOpByCollapsing
ControlFusionFn controlFoldingReshapes;
};
+/// Pattern to move a producer `tensor.expand_shape` below its consumer
+/// `tensor.pad`, i.e. rewrite
+///
+///   pad(expand_shape(%src))  ->  expand_shape(pad(%src))
+///
+/// The rewrite only applies when
+///   - the expand_shape result has a single use,
+///   - the control function allows fusing along the pad's source operand,
+///   - the pad has a padding value that can be reused outside its body, and
+///   - every expanded dimension that belongs to a reassociation group of more
+///     than one dimension has a static low and high padding of exactly 0.
+class FoldPadWithProducerReshapeOpByCollapsing
+    : public OpRewritePattern<tensor::PadOp> {
+public:
+  FoldPadWithProducerReshapeOpByCollapsing(MLIRContext *context,
+                                           ControlFusionFn foldReshapes,
+                                           PatternBenefit benefit = 1)
+      : OpRewritePattern<tensor::PadOp>(context, benefit),
+        controlFoldingReshapes(std::move(foldReshapes)) {}
+
+  LogicalResult matchAndRewrite(tensor::PadOp padOp,
+                                PatternRewriter &rewriter) const override {
+    tensor::ExpandShapeOp reshapeOp =
+        padOp.getSource().getDefiningOp<tensor::ExpandShapeOp>();
+    if (!reshapeOp)
+      return failure();
+    if (!reshapeOp->hasOneUse())
+      return failure();
+
+    if (!controlFoldingReshapes(&padOp.getSourceMutable())) {
+      return rewriter.notifyMatchFailure(padOp,
+                                         "fusion blocked by control function");
+    }
+
+    // The new pad op is built from a single padding value. Bail out before
+    // creating any IR if the original pad does not provide one; passing a
+    // null value to the builder would produce an invalid op.
+    Value paddingValue = padOp.getConstantPaddingValue();
+    if (!paddingValue) {
+      return rewriter.notifyMatchFailure(
+          padOp, "cannot fold with non-constant padding value");
+    }
+
+    ArrayRef<int64_t> low = padOp.getStaticLow();
+    ArrayRef<int64_t> high = padOp.getStaticHigh();
+    SmallVector<ReassociationIndices> reassociations =
+        reshapeOp.getReassociationIndices();
+
+    // Dimensions that are split out of one source dimension must not be
+    // padded: only a static padding of 0 is accepted inside such groups.
+    for (const ReassociationIndices &reInd : reassociations) {
+      if (reInd.size() == 1)
+        continue;
+      if (llvm::any_of(reInd, [&](int64_t ind) {
+            return low[ind] != 0 || high[ind] != 0;
+          })) {
+        return rewriter.notifyMatchFailure(
+            padOp, "padding on a dimension produced by expansion");
+      }
+    }
+
+    // Query the mixed (static or dynamic) pad amounts once instead of once per
+    // reassociation group.
+    SmallVector<OpFoldResult> mixedLow = padOp.getMixedLowPad();
+    SmallVector<OpFoldResult> mixedHigh = padOp.getMixedHighPad();
+
+    SmallVector<OpFoldResult> newLow, newHigh;
+    RankedTensorType collapsedType = reshapeOp.getSrcType();
+    RankedTensorType paddedType = padOp.getResultType();
+    // Start from the collapsed source shape and the expand_shape output sizes;
+    // both are only updated for the padded (singleton-group) dimensions.
+    SmallVector<int64_t> collapsedPaddedShape(collapsedType.getShape());
+    SmallVector<OpFoldResult> expandedPaddedSizes(
+        getMixedValues(reshapeOp.getStaticOutputShape(),
+                       reshapeOp.getOutputShape(), rewriter));
+    // Map computing `low + high + size` for a padded dimension.
+    AffineExpr d0, d1, d2;
+    bindDims(rewriter.getContext(), d0, d1, d2);
+    auto addMap = AffineMap::get(3, 0, {d0 + d1 + d2});
+    Location loc = reshapeOp->getLoc();
+    for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
+      // For groups with several dimensions the checks above guarantee a static
+      // padding of 0 on every member, so the first member is representative.
+      OpFoldResult l = mixedLow[reInd[0]];
+      OpFoldResult h = mixedHigh[reInd[0]];
+      if (reInd.size() == 1) {
+        collapsedPaddedShape[idx] = paddedType.getShape()[reInd[0]];
+        OpFoldResult paddedSize = affine::makeComposedFoldedAffineApply(
+            rewriter, loc, addMap, {l, h, expandedPaddedSizes[reInd[0]]});
+        expandedPaddedSizes[reInd[0]] = paddedSize;
+      }
+      newLow.push_back(l);
+      newHigh.push_back(h);
+    }
+
+    RankedTensorType collapsedPaddedType =
+        paddedType.clone(collapsedPaddedShape);
+    auto newPadOp = rewriter.create<tensor::PadOp>(
+        loc, collapsedPaddedType, reshapeOp.getSrc(), newLow, newHigh,
+        paddingValue, padOp.getNofold());
+
+    rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
+        padOp, padOp.getResultType(), newPadOp.getResult(), reassociations,
+        expandedPaddedSizes);
+
+    return success();
+  }
+
+private:
+  /// Callback deciding whether fusion along a given operand is allowed.
+  ControlFusionFn controlFoldingReshapes;
+};
+
/// Pattern to collapse dimensions.
template <typename LinalgType>
class CollapseLinalgDimensions : public OpRewritePattern<LinalgType> {
@@ -1937,6 +2079,8 @@ void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
const ControlFusionFn &controlFoldingReshapes) {
patterns.add<FoldReshapeWithGenericOpByExpansion>(patterns.getContext(),
controlFoldingReshapes);
+ patterns.add<FoldPadWithProducerReshapeOpByExpansion>(patterns.getContext(),
+ controlFoldingReshapes);
patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
controlFoldingReshapes);
}
@@ -1946,6 +2090,8 @@ void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns(
const ControlFusionFn &controlFoldingReshapes) {
patterns.add<FoldWithProducerReshapeOpByCollapsing>(patterns.getContext(),
controlFoldingReshapes);
+ patterns.add<FoldPadWithProducerReshapeOpByCollapsing>(
+ patterns.getContext(), controlFoldingReshapes);
}
void mlir::linalg::populateElementwiseOpsFusionPatterns(
diff --git a/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir b/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
index 0d40df534a3bb..600f0dea31f4a 100644
--- a/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
+++ b/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
@@ -537,3 +537,71 @@ func.func @no_fold_non_consecutive_reduction_dims(%arg0 : tensor<?x?xi32>, %sz0:
// CHECK: %[[GENERIC:.+]] = linalg.generic
// CHECK-SAME: ins(%[[EXPAND_ARG0]] :
// CHECK: return %[[GENERIC]]
+
+// -----
+
+// Positive case: only expanded dims 0, 3 and 7 are padded, and each of them
+// forms its own reassociation group. The pad is expected to move onto the
+// collapsed source, followed by an expand_shape with the padded output shape.
+func.func @fuse_by_collapsing_pad(%arg0 : tensor<2x12x5x336x9xi32>) -> tensor<8x3x4x17x6x7x8x14xi32> {
+ %expand = tensor.expand_shape %arg0 [[0], [1, 2], [3], [4, 5, 6], [7]] output_shape [2, 3, 4, 5, 6, 7, 8, 9] : tensor<2x12x5x336x9xi32> into tensor<2x3x4x5x6x7x8x9xi32>
+ %cst = arith.constant 0 : i32
+ %padded_0 = tensor.pad %expand low[1, 0, 0, 8, 0, 0, 0, 3] high[5, 0, 0, 4, 0, 0, 0, 2] {
+ ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index,
+ %arg5: index, %arg6: index, %arg7: index, %arg8: index):
+ tensor.yield %cst : i32
+ } : tensor<2x3x4x5x6x7x8x9xi32> to tensor<8x3x4x17x6x7x8x14xi32>
+ return %padded_0 : tensor<8x3x4x17x6x7x8x14xi32>
+}
+// CHECK: func @fuse_by_collapsing_pad(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<2x12x5x336x9xi32>)
+// CHECK: %[[PAD:.+]] = tensor.pad %[[ARG0]]
+// CHECK-SAME: low[1, 0, 8, 0, 3] high[5, 0, 4, 0, 2]
+// CHECK: tensor<2x12x5x336x9xi32> to tensor<8x12x17x336x14xi32>
+// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[PAD]] {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]]
+// CHECK-SAME: output_shape [8, 3, 4, 17, 6, 7, 8, 14] : tensor<8x12x17x336x14xi32> into tensor<8x3x4x17x6x7x8x14xi32>
+// CHECK: return %[[EXPAND]]
+
+// -----
+
+// Negative case: expanded dim 1 has a low padding of 2 and belongs to the
+// multi-dimension reassociation group [1, 2]. The expand_shape and pad are
+// expected to be left in their original order.
+func.func @no_fuse_by_collapsing_pad(%arg0 : tensor<2x12x5x336x9xi32>) -> tensor<8x5x4x17x6x7x8x14xi32> {
+ %expand = tensor.expand_shape %arg0 [[0], [1, 2], [3], [4, 5, 6], [7]] output_shape [2, 3, 4, 5, 6, 7, 8, 9] : tensor<2x12x5x336x9xi32> into tensor<2x3x4x5x6x7x8x9xi32>
+ %cst = arith.constant 0 : i32
+ %padded_0 = tensor.pad %expand low[1, 2, 0, 8, 0, 0, 0, 3] high[5, 0, 0, 4, 0, 0, 0, 2] {
+ ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index,
+ %arg5: index, %arg6: index, %arg7: index, %arg8: index):
+ tensor.yield %cst : i32
+ } : tensor<2x3x4x5x6x7x8x9xi32> to tensor<8x5x4x17x6x7x8x14xi32>
+ return %padded_0 : tensor<8x5x4x17x6x7x8x14xi32>
+}
+// CHECK: func @no_fuse_by_collapsing_pad(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<2x12x5x336x9xi32>)
+// CHECK: %[[EXPAND_ARG0:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]]
+// CHECK-SAME: output_shape [2, 3, 4, 5, 6, 7, 8, 9] : tensor<2x12x5x336x9xi32> into tensor<2x3x4x5x6x7x8x9xi32>
+// CHECK: %[[PAD:.+]] = tensor.pad %[[EXPAND_ARG0]]
+// CHECK-SAME: low[1, 2, 0, 8, 0, 0, 0, 3] high[5, 0, 0, 4, 0, 0, 0, 2]
+// CHECK: tensor<2x3x4x5x6x7x8x9xi32> to tensor<8x5x4x17x6x7x8x14xi32>
+// CHECK: return %[[PAD]]
+
+// -----
+
+// Dynamic-shape positive case: dynamic low/high padding is applied only to
+// expanded dims 0 and 3, which are singleton reassociation groups. The new
+// expand_shape output sizes for those dims are expected to be computed as
+// low + high + original size through affine.apply.
+func.func @fuse_by_collapsing_dynamic_pad(%arg0 : tensor<?x?x?x?xf32>,
+ %s0 : index, %s1 : index, %s2 : index, %s3 : index, %s4 : index, %s5 : index,
+ %l0 : index, %l1 : index, %h0 : index, %h1 : index) -> tensor<?x?x?x?x?x?xf32> {
+ %expand = tensor.expand_shape %arg0 [[0], [1, 2], [3], [4, 5]] output_shape [%s0, %s1, %s2, %s3, %s4, %s5] : tensor<?x?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
+ %cst = arith.constant 0.0 : f32
+ %padded_0 = tensor.pad %expand low[%l0, 0, 0, %l1, 0, 0] high[%h0, 0, 0, %h1, 0, 0] {
+ ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: index):
+ tensor.yield %cst : f32
+ } : tensor<?x?x?x?x?x?xf32> to tensor<?x?x?x?x?x?xf32>
+ return %padded_0 : tensor<?x?x?x?x?x?xf32>
+}
+// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1, s2] -> (s0 + s1 + s2)>
+// CHECK: func @fuse_by_collapsing_dynamic_pad(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?x?xf32>
+// CHECK-SAME: %[[S0:.+]]: index, %[[S1:.+]]: index, %[[S2:.+]]: index, %[[S3:.+]]: index, %[[S4:.+]]: index, %[[S5:.+]]: index, %[[L0:.+]]: index, %[[L1:.+]]: index, %[[H0:.+]]: index, %[[H1:.+]]: index
+// CHECK: %[[PAD_SIZE0:.+]] = affine.apply #[[MAP]]()[%[[L0]], %[[H0]], %[[S0]]]
+// CHECK: %[[PAD_SIZE1:.+]] = affine.apply #[[MAP]]()[%[[L1]], %[[H1]], %[[S3]]]
+// CHECK: %[[PAD:.+]] = tensor.pad %[[ARG0]]
+// CHECK-SAME: low[%[[L0]], 0, %[[L1]], 0] high[%[[H0]], 0, %[[H1]], 0]
+// CHECK: tensor<?x?x?x?xf32> to tensor<?x?x?x?xf32>
+// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[PAD]] {{\[}}[0], [1, 2], [3], [4, 5]]
+// CHECK-SAME: output_shape [%[[PAD_SIZE0]], %[[S1]], %[[S2]], %[[PAD_SIZE1]], %[[S4]], %[[S5]]] : tensor<?x?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
+// CHECK: return %[[EXPAND]]
diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
index f42666f81bbad..b8df5fc88e199 100644
--- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -826,3 +826,64 @@ func.func @linalg_add_reshape_producer_fusion(%arg0 : tensor<?x7x?x8xf32>,
// CHECK-SAME: [0, 1], [2, 3]
// CHECK-SAME: tensor<?x7x?x8xf32> into tensor<?x?xf32>
// CHECK: return %[[T4]]
+
+// -----
+
+// Positive case: only collapsed dims 0, 2 and 4 are padded, and each of them
+// comes from a singleton reassociation group. The pad is expected to move onto
+// the un-collapsed source, followed by the collapse_shape.
+func.func @fuse_by_expanding_pad(%arg0 : tensor<2x3x4x5x6x7x8x9xi32>) -> tensor<8x12x17x336x14xi32> {
+ %collapse = tensor.collapse_shape %arg0 [[0], [1, 2], [3], [4, 5, 6], [7]] : tensor<2x3x4x5x6x7x8x9xi32> into tensor<2x12x5x336x9xi32>
+ %cst = arith.constant 0 : i32
+ %padded_0 = tensor.pad %collapse low[1, 0, 8, 0, 3] high[5, 0, 4, 0, 2] {
+ ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index):
+ tensor.yield %cst : i32
+ } : tensor<2x12x5x336x9xi32> to tensor<8x12x17x336x14xi32>
+ return %padded_0 : tensor<8x12x17x336x14xi32>
+}
+// CHECK: func @fuse_by_expanding_pad(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<2x3x4x5x6x7x8x9xi32>)
+// CHECK: %[[PAD:.+]] = tensor.pad %[[ARG0]]
+// CHECK-SAME: low[1, 0, 0, 8, 0, 0, 0, 3] high[5, 0, 0, 4, 0, 0, 0, 2]
+// CHECK: tensor<2x3x4x5x6x7x8x9xi32> to tensor<8x3x4x17x6x7x8x14xi32>
+// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[PAD]] {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]]
+// CHECK-SAME: : tensor<8x3x4x17x6x7x8x14xi32> into tensor<8x12x17x336x14xi32>
+// CHECK: return %[[COLLAPSE]]
+
+// -----
+
+// Negative case: collapsed dim 3 has a high padding of 3 and is produced by
+// the multi-dimension reassociation group [4, 5, 6]. The collapse_shape and
+// pad are expected to be left in their original order.
+func.func @no_fuse_by_expanding_pad(%arg0 : tensor<2x3x4x5x6x7x8x9xi32>) -> tensor<8x12x17x339x14xi32> {
+ %collapse = tensor.collapse_shape %arg0 [[0], [1, 2], [3], [4, 5, 6], [7]] : tensor<2x3x4x5x6x7x8x9xi32> into tensor<2x12x5x336x9xi32>
+ %cst = arith.constant 0 : i32
+ %padded_0 = tensor.pad %collapse low[1, 0, 8, 0, 3] high[5, 0, 4, 3, 2] {
+ ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index):
+ tensor.yield %cst : i32
+ } : tensor<2x12x5x336x9xi32> to tensor<8x12x17x339x14xi32>
+ return %padded_0 : tensor<8x12x17x339x14xi32>
+}
+// CHECK: func @no_fuse_by_expanding_pad(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<2x3x4x5x6x7x8x9xi32>)
+// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]]
+// CHECK-SAME: : tensor<2x3x4x5x6x7x8x9xi32> into tensor<2x12x5x336x9xi32>
+// CHECK: %[[PAD:.+]] = tensor.pad %[[COLLAPSE]]
+// CHECK-SAME: low[1, 0, 8, 0, 3] high[5, 0, 4, 3, 2]
+// CHECK: tensor<2x12x5x336x9xi32> to tensor<8x12x17x339x14xi32>
+// CHECK: return %[[PAD]]
+
+// -----
+
+// Dynamic-shape positive case: dynamic low/high padding is applied only to
+// collapsed dims 0 and 2, which come from singleton reassociation groups. The
+// pad is expected to move onto the un-collapsed source with the same dynamic
+// pad operands on source dims 0 and 3.
+func.func @fuse_by_expanding_dynamic_pad(%arg0 : tensor<?x?x?x?x?x?xi32>, %l0: index, %l1: index, %h0: index, %h1: index) -> tensor<?x?x?x?xi32> {
+ %collapse = tensor.collapse_shape %arg0 [[0], [1, 2], [3], [4, 5]] : tensor<?x?x?x?x?x?xi32> into tensor<?x?x?x?xi32>
+ %cst = arith.constant 0 : i32
+ %padded_0 = tensor.pad %collapse low[%l0, 0, %l1, 0] high[%h0, 0, %h1, 0] {
+ ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index):
+ tensor.yield %cst : i32
+ } : tensor<?x?x?x?xi32> to tensor<?x?x?x?xi32>
+ return %padded_0 : tensor<?x?x?x?xi32>
+}
+// CHECK: func @fuse_by_expanding_dynamic_pad(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?x?x?x?xi32>
+// CHECK-SAME: %[[L0:.+]]: index, %[[L1:.+]]: index, %[[H0:.+]]: index, %[[H1:.+]]: index
+// CHECK: %[[PAD:.+]] = tensor.pad %[[ARG0]]
+// CHECK-SAME: low[%[[L0]], 0, 0, %[[L1]], 0, 0] high[%[[H0]], 0, 0, %[[H1]], 0, 0]
+// CHECK: tensor<?x?x?x?x?x?xi32> to tensor<?x?x?x?x?x?xi32>
+// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[PAD]] {{\[}}[0], [1, 2], [3], [4, 5]]
+// CHECK-SAME: : tensor<?x?x?x?x?x?xi32> into tensor<?x?x?x?xi32>
+// CHECK: return %[[COLLAPSE]]
More information about the Mlir-commits
mailing list