[Mlir-commits] [mlir] [mlir] Add reshape propagation patterns for tensor.pad (PR #94489)
llvmlistbot at llvm.org
Wed Jun 5 08:43:14 PDT 2024
https://github.com/Max191 created https://github.com/llvm/llvm-project/pull/94489
This PR adds fusion-by-collapsing and fusion-by-expansion patterns for `tensor.pad` ops in ElementwiseOpFusion. A pad op can be expanded or collapsed as long as none of the padded dimensions is itself expanded or collapsed.
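For example, a `tensor.pad` whose source is a `tensor.collapse_shape` can have the reshape propagated below it so that the pad applies to the expanded tensor. The sketch below is adapted from the `fuse_by_expanding_pad` test added in this PR; the collapsing direction (a pad fed by a `tensor.expand_shape`) mirrors it.

```mlir
// Before: the padded dimensions (0, 2, and 4 of the collapsed tensor) map to
// size-1 reassociation groups, i.e. they are not collapsed by the reshape.
%cst = arith.constant 0 : i32
%collapse = tensor.collapse_shape %arg0 [[0], [1, 2], [3], [4, 5, 6], [7]]
    : tensor<2x3x4x5x6x7x8x9xi32> into tensor<2x12x5x336x9xi32>
%padded = tensor.pad %collapse low[1, 0, 8, 0, 3] high[5, 0, 4, 0, 2] {
^bb0(%i0: index, %i1: index, %i2: index, %i3: index, %i4: index):
  tensor.yield %cst : i32
} : tensor<2x12x5x336x9xi32> to tensor<8x12x17x336x14xi32>

// After FoldPadWithProducerReshapeOpByExpansion: the pad is rewritten on the
// expanded source and the collapse_shape moves below it.
%padded = tensor.pad %arg0 low[1, 0, 0, 8, 0, 0, 0, 3] high[5, 0, 0, 4, 0, 0, 0, 2] {
^bb0(%i0: index, %i1: index, %i2: index, %i3: index,
     %i4: index, %i5: index, %i6: index, %i7: index):
  tensor.yield %cst : i32
} : tensor<2x3x4x5x6x7x8x9xi32> to tensor<8x3x4x17x6x7x8x14xi32>
%collapse = tensor.collapse_shape %padded [[0], [1, 2], [3], [4, 5, 6], [7]]
    : tensor<8x3x4x17x6x7x8x14xi32> into tensor<8x12x17x336x14xi32>
```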
From a33e68ba5b38b2e8c3b8ff20bc6f998ae8dfdd42 Mon Sep 17 00:00:00 2001
From: Max Dawkins <max.dawkins at gmail.com>
Date: Thu, 30 May 2024 17:56:52 -0400
Subject: [PATCH 1/2] [mlir] Add reshape propagation patterns for tensor.pad
---
.../Linalg/Transforms/ElementwiseOpFusion.cpp | 125 ++++++++++++++++++
1 file changed, 125 insertions(+)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index ad313c2d5ce60..4f0c5835ad823 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -16,6 +16,7 @@
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
@@ -956,6 +957,64 @@ class FoldWithProducerReshapeOpByExpansion
ControlFusionFn controlFoldingReshapes;
};
+class FoldPadWithProducerReshapeOpByExpansion
+ : public OpRewritePattern<tensor::PadOp> {
+public:
+ FoldPadWithProducerReshapeOpByExpansion(MLIRContext *context,
+ ControlFusionFn foldReshapes,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern<tensor::PadOp>(context, benefit),
+ controlFoldingReshapes(std::move(foldReshapes)) {}
+
+ LogicalResult matchAndRewrite(tensor::PadOp padOp,
+ PatternRewriter &rewriter) const override {
+ tensor::CollapseShapeOp reshapeOp =
+ padOp.getSource().getDefiningOp<tensor::CollapseShapeOp>();
+ if (!reshapeOp)
+ return failure();
+ if (!reshapeOp->hasOneUse())
+ return failure();
+
+ ArrayRef<int64_t> low = padOp.getStaticLow();
+ ArrayRef<int64_t> high = padOp.getStaticHigh();
+ SmallVector<ReassociationIndices> reassociations =
+ reshapeOp.getReassociationIndices();
+
+ for (auto [reInd, l, h] : llvm::zip_equal(reassociations, low, high)) {
+ if (reInd.size() != 1 && l != 0 && h != 0)
+ return failure();
+ }
+
+ SmallVector<OpFoldResult> newLow, newHigh;
+ RankedTensorType expandedType = reshapeOp.getSrcType();
+ RankedTensorType paddedType = padOp.getResultType();
+ SmallVector<int64_t> expandedPaddedShape(expandedType.getShape());
+ for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
+ if (reInd.size() == 1) {
+ expandedPaddedShape[reInd[0]] = paddedType.getShape()[idx];
+ }
+ for (auto ind : reInd) {
+ newLow.push_back(padOp.getMixedLowPad()[idx]);
+ newHigh.push_back(padOp.getMixedHighPad()[idx]);
+ }
+ }
+
+ Location loc = padOp->getLoc();
+ RankedTensorType expandedPaddedType = paddedType.clone(expandedPaddedShape);
+ auto newPadOp = rewriter.create<tensor::PadOp>(
+ loc, expandedPaddedType, reshapeOp.getSrc(), newLow, newHigh,
+ padOp.getConstantPaddingValue(), padOp.getNofold());
+
+ rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
+ padOp, padOp.getResultType(), newPadOp.getResult(), reassociations);
+
+ return success();
+ }
+
+private:
+ ControlFusionFn controlFoldingReshapes;
+};
+
/// Pattern to fold a tensor.expand_shape op with its producer generic op
/// by expanding the dimensionality of the loop in the producer op.
struct FoldReshapeWithGenericOpByExpansion
@@ -1702,6 +1761,68 @@ class FoldWithProducerReshapeOpByCollapsing
ControlFusionFn controlFoldingReshapes;
};
+class FoldPadWithProducerReshapeOpByCollapsing
+ : public OpRewritePattern<tensor::PadOp> {
+public:
+ FoldPadWithProducerReshapeOpByCollapsing(MLIRContext *context,
+ ControlFusionFn foldReshapes,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern<tensor::PadOp>(context, benefit),
+ controlFoldingReshapes(std::move(foldReshapes)) {}
+
+ LogicalResult matchAndRewrite(tensor::PadOp padOp,
+ PatternRewriter &rewriter) const override {
+ tensor::ExpandShapeOp reshapeOp =
+ padOp.getSource().getDefiningOp<tensor::ExpandShapeOp>();
+ if (!reshapeOp)
+ return failure();
+ if (!reshapeOp->hasOneUse())
+ return failure();
+
+ ArrayRef<int64_t> low = padOp.getStaticLow();
+ ArrayRef<int64_t> high = padOp.getStaticHigh();
+ SmallVector<ReassociationIndices> reassociations =
+ reshapeOp.getReassociationIndices();
+
+ for (auto reInd : reassociations) {
+ if (reInd.size() == 1)
+ continue;
+ if (llvm::any_of(reInd, [&](int64_t ind) {
+ return low[ind] != 0 || high[ind] != 0;
+ })) {
+ return failure();
+ }
+ }
+
+ SmallVector<OpFoldResult> newLow, newHigh;
+ RankedTensorType collapsedType = reshapeOp.getSrcType();
+ RankedTensorType paddedType = padOp.getResultType();
+ SmallVector<int64_t> collapsedPaddedShape(collapsedType.getShape());
+ for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
+ if (reInd.size() == 1) {
+ collapsedPaddedShape[idx] = paddedType.getShape()[reInd[0]];
+ }
+ newLow.push_back(padOp.getMixedLowPad()[reInd[0]]);
+ newHigh.push_back(padOp.getMixedHighPad()[reInd[0]]);
+ }
+
+ Location loc = padOp->getLoc();
+ RankedTensorType collapsedPaddedType =
+ paddedType.clone(collapsedPaddedShape);
+ auto newPadOp = rewriter.create<tensor::PadOp>(
+ loc, collapsedPaddedType, reshapeOp.getSrc(), newLow, newHigh,
+ padOp.getConstantPaddingValue(), padOp.getNofold());
+
+ rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
+ padOp, padOp.getResultType(), newPadOp.getResult(), reassociations);
+
+ return success();
+ }
+
+private:
+ ControlFusionFn controlFoldingReshapes;
+};
+
/// Pattern to collapse dimensions.
template <typename LinalgType>
class CollapseLinalgDimensions : public OpRewritePattern<LinalgType> {
@@ -1937,6 +2058,8 @@ void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
const ControlFusionFn &controlFoldingReshapes) {
patterns.add<FoldReshapeWithGenericOpByExpansion>(patterns.getContext(),
controlFoldingReshapes);
+ // patterns.add<FoldPadWithProducerReshapeOpByExpansion>(patterns.getContext(),
+ // controlFoldingReshapes);
patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
controlFoldingReshapes);
}
@@ -1946,6 +2069,8 @@ void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns(
const ControlFusionFn &controlFoldingReshapes) {
patterns.add<FoldWithProducerReshapeOpByCollapsing>(patterns.getContext(),
controlFoldingReshapes);
+ // patterns.add<FoldPadWithProducerReshapeOpByCollapsing>(
+ // patterns.getContext(), controlFoldingReshapes);
}
void mlir::linalg::populateElementwiseOpsFusionPatterns(
From 1735c4972a1655ecabe0cd178d9932580170b922 Mon Sep 17 00:00:00 2001
From: Max Dawkins <max.dawkins at gmail.com>
Date: Wed, 5 Jun 2024 09:54:51 -0400
Subject: [PATCH 2/2] add tests, support dynamic expand
---
.../Linalg/Transforms/ElementwiseOpFusion.cpp | 33 ++++++---
.../fuse-with-reshape-by-collapsing.mlir | 68 +++++++++++++++++++
mlir/test/Dialect/Linalg/reshape_fusion.mlir | 61 +++++++++++++++++
3 files changed, 151 insertions(+), 11 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 4f0c5835ad823..d93ef9138c474 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -16,7 +16,6 @@
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
-#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
@@ -981,7 +980,7 @@ class FoldPadWithProducerReshapeOpByExpansion
reshapeOp.getReassociationIndices();
for (auto [reInd, l, h] : llvm::zip_equal(reassociations, low, high)) {
- if (reInd.size() != 1 && l != 0 && h != 0)
+ if (reInd.size() != 1 && (l != 0 || h != 0))
return failure();
}
@@ -993,7 +992,7 @@ class FoldPadWithProducerReshapeOpByExpansion
if (reInd.size() == 1) {
expandedPaddedShape[reInd[0]] = paddedType.getShape()[idx];
}
- for (auto ind : reInd) {
+ for (size_t i = 0; i < reInd.size(); ++i) {
newLow.push_back(padOp.getMixedLowPad()[idx]);
newHigh.push_back(padOp.getMixedHighPad()[idx]);
}
@@ -1798,15 +1797,26 @@ class FoldPadWithProducerReshapeOpByCollapsing
RankedTensorType collapsedType = reshapeOp.getSrcType();
RankedTensorType paddedType = padOp.getResultType();
SmallVector<int64_t> collapsedPaddedShape(collapsedType.getShape());
+ SmallVector<OpFoldResult> expandedPaddedSizes(
+ getMixedValues(reshapeOp.getStaticOutputShape(),
+ reshapeOp.getOutputShape(), rewriter));
+ AffineExpr d0, d1, d2;
+ bindDims(rewriter.getContext(), d0, d1, d2);
+ auto addMap = AffineMap::get(3, 0, {d0 + d1 + d2});
+ Location loc = reshapeOp->getLoc();
for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
+ OpFoldResult l = padOp.getMixedLowPad()[reInd[0]];
+ OpFoldResult h = padOp.getMixedHighPad()[reInd[0]];
if (reInd.size() == 1) {
collapsedPaddedShape[idx] = paddedType.getShape()[reInd[0]];
+ OpFoldResult paddedSize = affine::makeComposedFoldedAffineApply(
+ rewriter, loc, addMap, {l, h, expandedPaddedSizes[reInd[0]]});
+ expandedPaddedSizes[reInd[0]] = paddedSize;
}
- newLow.push_back(padOp.getMixedLowPad()[reInd[0]]);
- newHigh.push_back(padOp.getMixedHighPad()[reInd[0]]);
+ newLow.push_back(l);
+ newHigh.push_back(h);
}
- Location loc = padOp->getLoc();
RankedTensorType collapsedPaddedType =
paddedType.clone(collapsedPaddedShape);
auto newPadOp = rewriter.create<tensor::PadOp>(
@@ -1814,7 +1824,8 @@ class FoldPadWithProducerReshapeOpByCollapsing
padOp.getConstantPaddingValue(), padOp.getNofold());
rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
- padOp, padOp.getResultType(), newPadOp.getResult(), reassociations);
+ padOp, padOp.getResultType(), newPadOp.getResult(), reassociations,
+ expandedPaddedSizes);
return success();
}
@@ -2058,8 +2069,8 @@ void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
const ControlFusionFn &controlFoldingReshapes) {
patterns.add<FoldReshapeWithGenericOpByExpansion>(patterns.getContext(),
controlFoldingReshapes);
- // patterns.add<FoldPadWithProducerReshapeOpByExpansion>(patterns.getContext(),
- // controlFoldingReshapes);
+ patterns.add<FoldPadWithProducerReshapeOpByExpansion>(patterns.getContext(),
+ controlFoldingReshapes);
patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
controlFoldingReshapes);
}
@@ -2069,8 +2080,8 @@ void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns(
const ControlFusionFn &controlFoldingReshapes) {
patterns.add<FoldWithProducerReshapeOpByCollapsing>(patterns.getContext(),
controlFoldingReshapes);
- // patterns.add<FoldPadWithProducerReshapeOpByCollapsing>(
- // patterns.getContext(), controlFoldingReshapes);
+ patterns.add<FoldPadWithProducerReshapeOpByCollapsing>(
+ patterns.getContext(), controlFoldingReshapes);
}
void mlir::linalg::populateElementwiseOpsFusionPatterns(
diff --git a/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir b/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
index 0d40df534a3bb..600f0dea31f4a 100644
--- a/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
+++ b/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
@@ -537,3 +537,71 @@ func.func @no_fold_non_consecutive_reduction_dims(%arg0 : tensor<?x?xi32>, %sz0:
// CHECK: %[[GENERIC:.+]] = linalg.generic
// CHECK-SAME: ins(%[[EXPAND_ARG0]] :
// CHECK: return %[[GENERIC]]
+
+// -----
+
+func.func @fuse_by_collapsing_pad(%arg0 : tensor<2x12x5x336x9xi32>) -> tensor<8x3x4x17x6x7x8x14xi32> {
+ %expand = tensor.expand_shape %arg0 [[0], [1, 2], [3], [4, 5, 6], [7]] output_shape [2, 3, 4, 5, 6, 7, 8, 9] : tensor<2x12x5x336x9xi32> into tensor<2x3x4x5x6x7x8x9xi32>
+ %cst = arith.constant 0 : i32
+ %padded_0 = tensor.pad %expand low[1, 0, 0, 8, 0, 0, 0, 3] high[5, 0, 0, 4, 0, 0, 0, 2] {
+ ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index,
+ %arg5: index, %arg6: index, %arg7: index, %arg8: index):
+ tensor.yield %cst : i32
+ } : tensor<2x3x4x5x6x7x8x9xi32> to tensor<8x3x4x17x6x7x8x14xi32>
+ return %padded_0 : tensor<8x3x4x17x6x7x8x14xi32>
+}
+// CHECK: func @fuse_by_collapsing_pad(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<2x12x5x336x9xi32>)
+// CHECK: %[[PAD:.+]] = tensor.pad %[[ARG0]]
+// CHECK-SAME: low[1, 0, 8, 0, 3] high[5, 0, 4, 0, 2]
+// CHECK: tensor<2x12x5x336x9xi32> to tensor<8x12x17x336x14xi32>
+// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[PAD]] {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]]
+// CHECK-SAME: output_shape [8, 3, 4, 17, 6, 7, 8, 14] : tensor<8x12x17x336x14xi32> into tensor<8x3x4x17x6x7x8x14xi32>
+// CHECK: return %[[EXPAND]]
+
+// -----
+
+func.func @no_fuse_by_collapsing_pad(%arg0 : tensor<2x12x5x336x9xi32>) -> tensor<8x5x4x17x6x7x8x14xi32> {
+ %expand = tensor.expand_shape %arg0 [[0], [1, 2], [3], [4, 5, 6], [7]] output_shape [2, 3, 4, 5, 6, 7, 8, 9] : tensor<2x12x5x336x9xi32> into tensor<2x3x4x5x6x7x8x9xi32>
+ %cst = arith.constant 0 : i32
+ %padded_0 = tensor.pad %expand low[1, 2, 0, 8, 0, 0, 0, 3] high[5, 0, 0, 4, 0, 0, 0, 2] {
+ ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index,
+ %arg5: index, %arg6: index, %arg7: index, %arg8: index):
+ tensor.yield %cst : i32
+ } : tensor<2x3x4x5x6x7x8x9xi32> to tensor<8x5x4x17x6x7x8x14xi32>
+ return %padded_0 : tensor<8x5x4x17x6x7x8x14xi32>
+}
+// CHECK: func @no_fuse_by_collapsing_pad(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<2x12x5x336x9xi32>)
+// CHECK: %[[EXPAND_ARG0:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]]
+// CHECK-SAME: output_shape [2, 3, 4, 5, 6, 7, 8, 9] : tensor<2x12x5x336x9xi32> into tensor<2x3x4x5x6x7x8x9xi32>
+// CHECK: %[[PAD:.+]] = tensor.pad %[[EXPAND_ARG0]]
+// CHECK-SAME: low[1, 2, 0, 8, 0, 0, 0, 3] high[5, 0, 0, 4, 0, 0, 0, 2]
+// CHECK: tensor<2x3x4x5x6x7x8x9xi32> to tensor<8x5x4x17x6x7x8x14xi32>
+// CHECK: return %[[PAD]]
+
+// -----
+
+func.func @fuse_by_collapsing_dynamic_pad(%arg0 : tensor<?x?x?x?xf32>,
+ %s0 : index, %s1 : index, %s2 : index, %s3 : index, %s4 : index, %s5 : index,
+ %l0 : index, %l1 : index, %h0 : index, %h1 : index) -> tensor<?x?x?x?x?x?xf32> {
+ %expand = tensor.expand_shape %arg0 [[0], [1, 2], [3], [4, 5]] output_shape [%s0, %s1, %s2, %s3, %s4, %s5] : tensor<?x?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
+ %cst = arith.constant 0.0 : f32
+ %padded_0 = tensor.pad %expand low[%l0, 0, 0, %l1, 0, 0] high[%h0, 0, 0, %h1, 0, 0] {
+ ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: index):
+ tensor.yield %cst : f32
+ } : tensor<?x?x?x?x?x?xf32> to tensor<?x?x?x?x?x?xf32>
+ return %padded_0 : tensor<?x?x?x?x?x?xf32>
+}
+// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1, s2] -> (s0 + s1 + s2)>
+// CHECK: func @fuse_by_collapsing_dynamic_pad(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?x?xf32>
+// CHECK-SAME: %[[S0:.+]]: index, %[[S1:.+]]: index, %[[S2:.+]]: index, %[[S3:.+]]: index, %[[S4:.+]]: index, %[[S5:.+]]: index, %[[L0:.+]]: index, %[[L1:.+]]: index, %[[H0:.+]]: index, %[[H1:.+]]: index
+// CHECK: %[[PAD_SIZE0:.+]] = affine.apply #[[MAP]]()[%[[L0]], %[[H0]], %[[S0]]]
+// CHECK: %[[PAD_SIZE1:.+]] = affine.apply #[[MAP]]()[%[[L1]], %[[H1]], %[[S3]]]
+// CHECK: %[[PAD:.+]] = tensor.pad %[[ARG0]]
+// CHECK-SAME: low[%[[L0]], 0, %[[L1]], 0] high[%[[H0]], 0, %[[H1]], 0]
+// CHECK: tensor<?x?x?x?xf32> to tensor<?x?x?x?xf32>
+// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[PAD]] {{\[}}[0], [1, 2], [3], [4, 5]]
+// CHECK-SAME: output_shape [%[[PAD_SIZE0]], %[[S1]], %[[S2]], %[[PAD_SIZE1]], %[[S4]], %[[S5]]] : tensor<?x?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
+// CHECK: return %[[EXPAND]]
diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
index f42666f81bbad..b8df5fc88e199 100644
--- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -826,3 +826,64 @@ func.func @linalg_add_reshape_producer_fusion(%arg0 : tensor<?x7x?x8xf32>,
// CHECK-SAME: [0, 1], [2, 3]
// CHECK-SAME: tensor<?x7x?x8xf32> into tensor<?x?xf32>
// CHECK: return %[[T4]]
+
+// -----
+
+func.func @fuse_by_expanding_pad(%arg0 : tensor<2x3x4x5x6x7x8x9xi32>) -> tensor<8x12x17x336x14xi32> {
+ %collapse = tensor.collapse_shape %arg0 [[0], [1, 2], [3], [4, 5, 6], [7]] : tensor<2x3x4x5x6x7x8x9xi32> into tensor<2x12x5x336x9xi32>
+ %cst = arith.constant 0 : i32
+ %padded_0 = tensor.pad %collapse low[1, 0, 8, 0, 3] high[5, 0, 4, 0, 2] {
+ ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index):
+ tensor.yield %cst : i32
+ } : tensor<2x12x5x336x9xi32> to tensor<8x12x17x336x14xi32>
+ return %padded_0 : tensor<8x12x17x336x14xi32>
+}
+// CHECK: func @fuse_by_expanding_pad(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<2x3x4x5x6x7x8x9xi32>)
+// CHECK: %[[PAD:.+]] = tensor.pad %[[ARG0]]
+// CHECK-SAME: low[1, 0, 0, 8, 0, 0, 0, 3] high[5, 0, 0, 4, 0, 0, 0, 2]
+// CHECK: tensor<2x3x4x5x6x7x8x9xi32> to tensor<8x3x4x17x6x7x8x14xi32>
+// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[PAD]] {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]]
+// CHECK-SAME: : tensor<8x3x4x17x6x7x8x14xi32> into tensor<8x12x17x336x14xi32>
+// CHECK: return %[[COLLAPSE]]
+
+// -----
+
+func.func @no_fuse_by_expanding_pad(%arg0 : tensor<2x3x4x5x6x7x8x9xi32>) -> tensor<8x12x17x339x14xi32> {
+ %collapse = tensor.collapse_shape %arg0 [[0], [1, 2], [3], [4, 5, 6], [7]] : tensor<2x3x4x5x6x7x8x9xi32> into tensor<2x12x5x336x9xi32>
+ %cst = arith.constant 0 : i32
+ %padded_0 = tensor.pad %collapse low[1, 0, 8, 0, 3] high[5, 0, 4, 3, 2] {
+ ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index):
+ tensor.yield %cst : i32
+ } : tensor<2x12x5x336x9xi32> to tensor<8x12x17x339x14xi32>
+ return %padded_0 : tensor<8x12x17x339x14xi32>
+}
+// CHECK: func @no_fuse_by_expanding_pad(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<2x3x4x5x6x7x8x9xi32>)
+// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]]
+// CHECK-SAME: : tensor<2x3x4x5x6x7x8x9xi32> into tensor<2x12x5x336x9xi32>
+// CHECK: %[[PAD:.+]] = tensor.pad %[[COLLAPSE]]
+// CHECK-SAME: low[1, 0, 8, 0, 3] high[5, 0, 4, 3, 2]
+// CHECK: tensor<2x12x5x336x9xi32> to tensor<8x12x17x339x14xi32>
+// CHECK: return %[[PAD]]
+
+// -----
+
+func.func @fuse_by_expanding_dynamic_pad(%arg0 : tensor<?x?x?x?x?x?xi32>, %l0: index, %l1: index, %h0: index, %h1: index) -> tensor<?x?x?x?xi32> {
+ %collapse = tensor.collapse_shape %arg0 [[0], [1, 2], [3], [4, 5]] : tensor<?x?x?x?x?x?xi32> into tensor<?x?x?x?xi32>
+ %cst = arith.constant 0 : i32
+ %padded_0 = tensor.pad %collapse low[%l0, 0, %l1, 0] high[%h0, 0, %h1, 0] {
+ ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index):
+ tensor.yield %cst : i32
+ } : tensor<?x?x?x?xi32> to tensor<?x?x?x?xi32>
+ return %padded_0 : tensor<?x?x?x?xi32>
+}
+// CHECK: func @fuse_by_expanding_dynamic_pad(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?x?x?x?xi32>
+// CHECK-SAME: %[[L0:.+]]: index, %[[L1:.+]]: index, %[[H0:.+]]: index, %[[H1:.+]]: index
+// CHECK: %[[PAD:.+]] = tensor.pad %[[ARG0]]
+// CHECK-SAME: low[%[[L0]], 0, 0, %[[L1]], 0, 0] high[%[[H0]], 0, 0, %[[H1]], 0, 0]
+// CHECK: tensor<?x?x?x?x?x?xi32> to tensor<?x?x?x?x?x?xi32>
+// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[PAD]] {{\[}}[0], [1, 2], [3], [4, 5]]
+// CHECK-SAME: : tensor<?x?x?x?x?x?xi32> into tensor<?x?x?x?xi32>
+// CHECK: return %[[COLLAPSE]]