[Mlir-commits] [mlir] [mlir] Add reshape propagation patterns for tensor.pad (PR #94489)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Jun 5 08:43:14 PDT 2024


https://github.com/Max191 created https://github.com/llvm/llvm-project/pull/94489

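This PR adds fusion-by-expansion and fusion-by-collapsing patterns for `tensor.pad` ops to ElementwiseOpFusion. The expansion pattern matches a `tensor.pad` whose source is produced by a `tensor.collapse_shape`, and the collapsing pattern matches a `tensor.pad` whose source is produced by a `tensor.expand_shape`; in both cases the pad is moved above the reshape and the reshape is re-applied to the padded result. The rewrite only applies when none of the padded dimensions is expanded or collapsed, i.e. every dimension that belongs to a reassociation group of more than one dimension has zero low and high padding.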

>From a33e68ba5b38b2e8c3b8ff20bc6f998ae8dfdd42 Mon Sep 17 00:00:00 2001
From: Max Dawkins <max.dawkins at gmail.com>
Date: Thu, 30 May 2024 17:56:52 -0400
Subject: [PATCH 1/2] [mlir] Add reshape propagation patterns for tensor.pad

---
 .../Linalg/Transforms/ElementwiseOpFusion.cpp | 125 ++++++++++++++++++
 1 file changed, 125 insertions(+)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index ad313c2d5ce60..4f0c5835ad823 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
@@ -956,6 +957,64 @@ class FoldWithProducerReshapeOpByExpansion
   ControlFusionFn controlFoldingReshapes;
 };
 
+class FoldPadWithProducerReshapeOpByExpansion
+    : public OpRewritePattern<tensor::PadOp> {
+public:
+  FoldPadWithProducerReshapeOpByExpansion(MLIRContext *context,
+                                          ControlFusionFn foldReshapes,
+                                          PatternBenefit benefit = 1)
+      : OpRewritePattern<tensor::PadOp>(context, benefit),
+        controlFoldingReshapes(std::move(foldReshapes)) {}
+
+  LogicalResult matchAndRewrite(tensor::PadOp padOp,
+                                PatternRewriter &rewriter) const override {
+    tensor::CollapseShapeOp reshapeOp =
+        padOp.getSource().getDefiningOp<tensor::CollapseShapeOp>();
+    if (!reshapeOp)
+      return failure();
+    if (!reshapeOp->hasOneUse())
+      return failure();
+
+    ArrayRef<int64_t> low = padOp.getStaticLow();
+    ArrayRef<int64_t> high = padOp.getStaticHigh();
+    SmallVector<ReassociationIndices> reassociations =
+        reshapeOp.getReassociationIndices();
+
+    for (auto [reInd, l, h] : llvm::zip_equal(reassociations, low, high)) {
+      if (reInd.size() != 1 && l != 0 && h != 0)
+        return failure();
+    }
+
+    SmallVector<OpFoldResult> newLow, newHigh;
+    RankedTensorType expandedType = reshapeOp.getSrcType();
+    RankedTensorType paddedType = padOp.getResultType();
+    SmallVector<int64_t> expandedPaddedShape(expandedType.getShape());
+    for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
+      if (reInd.size() == 1) {
+        expandedPaddedShape[reInd[0]] = paddedType.getShape()[idx];
+      }
+      for (auto ind : reInd) {
+        newLow.push_back(padOp.getMixedLowPad()[idx]);
+        newHigh.push_back(padOp.getMixedHighPad()[idx]);
+      }
+    }
+
+    Location loc = padOp->getLoc();
+    RankedTensorType expandedPaddedType = paddedType.clone(expandedPaddedShape);
+    auto newPadOp = rewriter.create<tensor::PadOp>(
+        loc, expandedPaddedType, reshapeOp.getSrc(), newLow, newHigh,
+        padOp.getConstantPaddingValue(), padOp.getNofold());
+
+    rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
+        padOp, padOp.getResultType(), newPadOp.getResult(), reassociations);
+
+    return success();
+  }
+
+private:
+  ControlFusionFn controlFoldingReshapes;
+};
+
 /// Pattern to fold a tensor.expand_shape op with its producer generic op
 /// by expanding the dimensionality of the loop in the producer op.
 struct FoldReshapeWithGenericOpByExpansion
@@ -1702,6 +1761,68 @@ class FoldWithProducerReshapeOpByCollapsing
   ControlFusionFn controlFoldingReshapes;
 };
 
+class FoldPadWithProducerReshapeOpByCollapsing
+    : public OpRewritePattern<tensor::PadOp> {
+public:
+  FoldPadWithProducerReshapeOpByCollapsing(MLIRContext *context,
+                                           ControlFusionFn foldReshapes,
+                                           PatternBenefit benefit = 1)
+      : OpRewritePattern<tensor::PadOp>(context, benefit),
+        controlFoldingReshapes(std::move(foldReshapes)) {}
+
+  LogicalResult matchAndRewrite(tensor::PadOp padOp,
+                                PatternRewriter &rewriter) const override {
+    tensor::ExpandShapeOp reshapeOp =
+        padOp.getSource().getDefiningOp<tensor::ExpandShapeOp>();
+    if (!reshapeOp)
+      return failure();
+    if (!reshapeOp->hasOneUse())
+      return failure();
+
+    ArrayRef<int64_t> low = padOp.getStaticLow();
+    ArrayRef<int64_t> high = padOp.getStaticHigh();
+    SmallVector<ReassociationIndices> reassociations =
+        reshapeOp.getReassociationIndices();
+
+    for (auto reInd : reassociations) {
+      if (reInd.size() == 1)
+        continue;
+      if (llvm::any_of(reInd, [&](int64_t ind) {
+            return low[ind] != 0 || high[ind] != 0;
+          })) {
+        return failure();
+      }
+    }
+
+    SmallVector<OpFoldResult> newLow, newHigh;
+    RankedTensorType collapsedType = reshapeOp.getSrcType();
+    RankedTensorType paddedType = padOp.getResultType();
+    SmallVector<int64_t> collapsedPaddedShape(collapsedType.getShape());
+    for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
+      if (reInd.size() == 1) {
+        collapsedPaddedShape[idx] = paddedType.getShape()[reInd[0]];
+      }
+      newLow.push_back(padOp.getMixedLowPad()[reInd[0]]);
+      newHigh.push_back(padOp.getMixedHighPad()[reInd[0]]);
+    }
+
+    Location loc = padOp->getLoc();
+    RankedTensorType collapsedPaddedType =
+        paddedType.clone(collapsedPaddedShape);
+    auto newPadOp = rewriter.create<tensor::PadOp>(
+        loc, collapsedPaddedType, reshapeOp.getSrc(), newLow, newHigh,
+        padOp.getConstantPaddingValue(), padOp.getNofold());
+
+    rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
+        padOp, padOp.getResultType(), newPadOp.getResult(), reassociations);
+
+    return success();
+  }
+
+private:
+  ControlFusionFn controlFoldingReshapes;
+};
+
 /// Pattern to collapse dimensions.
 template <typename LinalgType>
 class CollapseLinalgDimensions : public OpRewritePattern<LinalgType> {
@@ -1937,6 +2058,8 @@ void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
     const ControlFusionFn &controlFoldingReshapes) {
   patterns.add<FoldReshapeWithGenericOpByExpansion>(patterns.getContext(),
                                                     controlFoldingReshapes);
+  // patterns.add<FoldPadWithProducerReshapeOpByExpansion>(patterns.getContext(),
+  //                                                       controlFoldingReshapes);
   patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
                                                      controlFoldingReshapes);
 }
@@ -1946,6 +2069,8 @@ void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns(
     const ControlFusionFn &controlFoldingReshapes) {
   patterns.add<FoldWithProducerReshapeOpByCollapsing>(patterns.getContext(),
                                                       controlFoldingReshapes);
+  // patterns.add<FoldPadWithProducerReshapeOpByCollapsing>(
+  //     patterns.getContext(), controlFoldingReshapes);
 }
 
 void mlir::linalg::populateElementwiseOpsFusionPatterns(

>From 1735c4972a1655ecabe0cd178d9932580170b922 Mon Sep 17 00:00:00 2001
From: Max Dawkins <max.dawkins at gmail.com>
Date: Wed, 5 Jun 2024 09:54:51 -0400
Subject: [PATCH 2/2] add tests, support dynamic expand

---
 .../Linalg/Transforms/ElementwiseOpFusion.cpp | 33 ++++++---
 .../fuse-with-reshape-by-collapsing.mlir      | 68 +++++++++++++++++++
 mlir/test/Dialect/Linalg/reshape_fusion.mlir  | 61 +++++++++++++++++
 3 files changed, 151 insertions(+), 11 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 4f0c5835ad823..d93ef9138c474 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -16,7 +16,6 @@
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
-#include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
@@ -981,7 +980,7 @@ class FoldPadWithProducerReshapeOpByExpansion
         reshapeOp.getReassociationIndices();
 
     for (auto [reInd, l, h] : llvm::zip_equal(reassociations, low, high)) {
-      if (reInd.size() != 1 && l != 0 && h != 0)
+      if (reInd.size() != 1 && (l != 0 || h != 0))
         return failure();
     }
 
@@ -993,7 +992,7 @@ class FoldPadWithProducerReshapeOpByExpansion
       if (reInd.size() == 1) {
         expandedPaddedShape[reInd[0]] = paddedType.getShape()[idx];
       }
-      for (auto ind : reInd) {
+      for (size_t i = 0; i < reInd.size(); ++i) {
         newLow.push_back(padOp.getMixedLowPad()[idx]);
         newHigh.push_back(padOp.getMixedHighPad()[idx]);
       }
@@ -1798,15 +1797,26 @@ class FoldPadWithProducerReshapeOpByCollapsing
     RankedTensorType collapsedType = reshapeOp.getSrcType();
     RankedTensorType paddedType = padOp.getResultType();
     SmallVector<int64_t> collapsedPaddedShape(collapsedType.getShape());
+    SmallVector<OpFoldResult> expandedPaddedSizes(
+        getMixedValues(reshapeOp.getStaticOutputShape(),
+                       reshapeOp.getOutputShape(), rewriter));
+    AffineExpr d0, d1, d2;
+    bindDims(rewriter.getContext(), d0, d1, d2);
+    auto addMap = AffineMap::get(3, 0, {d0 + d1 + d2});
+    Location loc = reshapeOp->getLoc();
     for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
+      OpFoldResult l = padOp.getMixedLowPad()[reInd[0]];
+      OpFoldResult h = padOp.getMixedHighPad()[reInd[0]];
       if (reInd.size() == 1) {
         collapsedPaddedShape[idx] = paddedType.getShape()[reInd[0]];
+        OpFoldResult paddedSize = affine::makeComposedFoldedAffineApply(
+            rewriter, loc, addMap, {l, h, expandedPaddedSizes[reInd[0]]});
+        expandedPaddedSizes[reInd[0]] = paddedSize;
       }
-      newLow.push_back(padOp.getMixedLowPad()[reInd[0]]);
-      newHigh.push_back(padOp.getMixedHighPad()[reInd[0]]);
+      newLow.push_back(l);
+      newHigh.push_back(h);
     }
 
-    Location loc = padOp->getLoc();
     RankedTensorType collapsedPaddedType =
         paddedType.clone(collapsedPaddedShape);
     auto newPadOp = rewriter.create<tensor::PadOp>(
@@ -1814,7 +1824,8 @@ class FoldPadWithProducerReshapeOpByCollapsing
         padOp.getConstantPaddingValue(), padOp.getNofold());
 
     rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
-        padOp, padOp.getResultType(), newPadOp.getResult(), reassociations);
+        padOp, padOp.getResultType(), newPadOp.getResult(), reassociations,
+        expandedPaddedSizes);
 
     return success();
   }
@@ -2058,8 +2069,8 @@ void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
     const ControlFusionFn &controlFoldingReshapes) {
   patterns.add<FoldReshapeWithGenericOpByExpansion>(patterns.getContext(),
                                                     controlFoldingReshapes);
-  // patterns.add<FoldPadWithProducerReshapeOpByExpansion>(patterns.getContext(),
-  //                                                       controlFoldingReshapes);
+  patterns.add<FoldPadWithProducerReshapeOpByExpansion>(patterns.getContext(),
+                                                        controlFoldingReshapes);
   patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
                                                      controlFoldingReshapes);
 }
@@ -2069,8 +2080,8 @@ void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns(
     const ControlFusionFn &controlFoldingReshapes) {
   patterns.add<FoldWithProducerReshapeOpByCollapsing>(patterns.getContext(),
                                                       controlFoldingReshapes);
-  // patterns.add<FoldPadWithProducerReshapeOpByCollapsing>(
-  //     patterns.getContext(), controlFoldingReshapes);
+  patterns.add<FoldPadWithProducerReshapeOpByCollapsing>(
+      patterns.getContext(), controlFoldingReshapes);
 }
 
 void mlir::linalg::populateElementwiseOpsFusionPatterns(
diff --git a/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir b/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
index 0d40df534a3bb..600f0dea31f4a 100644
--- a/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
+++ b/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
@@ -537,3 +537,71 @@ func.func @no_fold_non_consecutive_reduction_dims(%arg0 : tensor<?x?xi32>, %sz0:
 //      CHECK:   %[[GENERIC:.+]] = linalg.generic
 // CHECK-SAME:       ins(%[[EXPAND_ARG0]] :
 //      CHECK:   return %[[GENERIC]]
+
+// -----
+
+func.func @fuse_by_collapsing_pad(%arg0 : tensor<2x12x5x336x9xi32>) -> tensor<8x3x4x17x6x7x8x14xi32> {
+  %expand = tensor.expand_shape %arg0 [[0], [1, 2], [3], [4, 5, 6], [7]] output_shape [2, 3, 4, 5, 6, 7, 8, 9] : tensor<2x12x5x336x9xi32> into tensor<2x3x4x5x6x7x8x9xi32>
+  %cst = arith.constant 0 : i32
+  %padded_0 = tensor.pad %expand low[1, 0, 0, 8, 0, 0, 0, 3] high[5, 0, 0, 4, 0, 0, 0, 2] {
+  ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index,
+       %arg5: index, %arg6: index, %arg7: index, %arg8: index):
+    tensor.yield %cst : i32
+  } : tensor<2x3x4x5x6x7x8x9xi32> to tensor<8x3x4x17x6x7x8x14xi32>
+  return %padded_0 : tensor<8x3x4x17x6x7x8x14xi32>
+}
+//      CHECK: func @fuse_by_collapsing_pad(
+// CHECK-SAME:   %[[ARG0:.+]]: tensor<2x12x5x336x9xi32>)
+//      CHECK:   %[[PAD:.+]] = tensor.pad %[[ARG0]]
+// CHECK-SAME:       low[1, 0, 8, 0, 3] high[5, 0, 4, 0, 2]
+//      CHECK:       tensor<2x12x5x336x9xi32> to tensor<8x12x17x336x14xi32>
+//      CHECK:   %[[EXPAND:.+]] = tensor.expand_shape %[[PAD]] {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]]
+// CHECK-SAME:       output_shape [8, 3, 4, 17, 6, 7, 8, 14] : tensor<8x12x17x336x14xi32> into tensor<8x3x4x17x6x7x8x14xi32>
+//      CHECK:   return %[[EXPAND]]
+
+// -----
+
+func.func @no_fuse_by_collapsing_pad(%arg0 : tensor<2x12x5x336x9xi32>) -> tensor<8x5x4x17x6x7x8x14xi32> {
+  %expand = tensor.expand_shape %arg0 [[0], [1, 2], [3], [4, 5, 6], [7]] output_shape [2, 3, 4, 5, 6, 7, 8, 9] : tensor<2x12x5x336x9xi32> into tensor<2x3x4x5x6x7x8x9xi32>
+  %cst = arith.constant 0 : i32
+  %padded_0 = tensor.pad %expand low[1, 2, 0, 8, 0, 0, 0, 3] high[5, 0, 0, 4, 0, 0, 0, 2] {
+  ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index,
+       %arg5: index, %arg6: index, %arg7: index, %arg8: index):
+    tensor.yield %cst : i32
+  } : tensor<2x3x4x5x6x7x8x9xi32> to tensor<8x5x4x17x6x7x8x14xi32>
+  return %padded_0 : tensor<8x5x4x17x6x7x8x14xi32>
+}
+//      CHECK: func @no_fuse_by_collapsing_pad(
+// CHECK-SAME:   %[[ARG0:.+]]: tensor<2x12x5x336x9xi32>)
+//      CHECK:   %[[EXPAND_ARG0:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]]
+// CHECK-SAME:       output_shape [2, 3, 4, 5, 6, 7, 8, 9] : tensor<2x12x5x336x9xi32> into tensor<2x3x4x5x6x7x8x9xi32>
+//      CHECK:   %[[PAD:.+]] = tensor.pad %[[EXPAND_ARG0]]
+// CHECK-SAME:       low[1, 2, 0, 8, 0, 0, 0, 3] high[5, 0, 0, 4, 0, 0, 0, 2]
+//      CHECK:       tensor<2x3x4x5x6x7x8x9xi32> to tensor<8x5x4x17x6x7x8x14xi32>
+//      CHECK:   return %[[PAD]]
+
+// -----
+
+func.func @fuse_by_collapsing_dynamic_pad(%arg0 : tensor<?x?x?x?xf32>,
+    %s0 : index, %s1 : index, %s2 : index, %s3 : index, %s4 : index, %s5 : index,
+    %l0 : index, %l1 : index, %h0 : index, %h1 : index) -> tensor<?x?x?x?x?x?xf32> {
+  %expand = tensor.expand_shape %arg0 [[0], [1, 2], [3], [4, 5]] output_shape [%s0, %s1, %s2, %s3, %s4, %s5] : tensor<?x?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
+  %cst = arith.constant 0.0 : f32
+  %padded_0 = tensor.pad %expand low[%l0, 0, 0, %l1, 0, 0] high[%h0, 0, 0, %h1, 0, 0] {
+  ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: index):
+    tensor.yield %cst : f32
+  } : tensor<?x?x?x?x?x?xf32> to tensor<?x?x?x?x?x?xf32>
+  return %padded_0 : tensor<?x?x?x?x?x?xf32>
+}
+//  CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1, s2] -> (s0 + s1 + s2)>
+//      CHECK: func @fuse_by_collapsing_dynamic_pad(
+// CHECK-SAME:   %[[ARG0:.+]]: tensor<?x?x?x?xf32>
+// CHECK-SAME:   %[[S0:.+]]: index, %[[S1:.+]]: index, %[[S2:.+]]: index, %[[S3:.+]]: index, %[[S4:.+]]: index, %[[S5:.+]]: index, %[[L0:.+]]: index, %[[L1:.+]]: index, %[[H0:.+]]: index, %[[H1:.+]]: index
+//      CHECK:   %[[PAD_SIZE0:.+]] = affine.apply #[[MAP]]()[%[[L0]], %[[H0]], %[[S0]]]
+//      CHECK:   %[[PAD_SIZE1:.+]] = affine.apply #[[MAP]]()[%[[L1]], %[[H1]], %[[S3]]]
+//      CHECK:   %[[PAD:.+]] = tensor.pad %[[ARG0]]
+// CHECK-SAME:       low[%[[L0]], 0, %[[L1]], 0] high[%[[H0]], 0, %[[H1]], 0]
+//      CHECK:       tensor<?x?x?x?xf32> to tensor<?x?x?x?xf32>
+//      CHECK:   %[[EXPAND:.+]] = tensor.expand_shape %[[PAD]] {{\[}}[0], [1, 2], [3], [4, 5]]
+// CHECK-SAME:       output_shape [%[[PAD_SIZE0]], %[[S1]], %[[S2]], %[[PAD_SIZE1]], %[[S4]], %[[S5]]] : tensor<?x?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
+//      CHECK:   return %[[EXPAND]]
diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
index f42666f81bbad..b8df5fc88e199 100644
--- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -826,3 +826,64 @@ func.func @linalg_add_reshape_producer_fusion(%arg0 : tensor<?x7x?x8xf32>,
 // CHECK-SAME:     [0, 1], [2, 3]
 // CHECK-SAME:     tensor<?x7x?x8xf32> into tensor<?x?xf32>
 //      CHECK:   return %[[T4]]
+
+// -----
+
+func.func @fuse_by_expanding_pad(%arg0 : tensor<2x3x4x5x6x7x8x9xi32>) -> tensor<8x12x17x336x14xi32> {
+  %collapse = tensor.collapse_shape %arg0 [[0], [1, 2], [3], [4, 5, 6], [7]] : tensor<2x3x4x5x6x7x8x9xi32> into tensor<2x12x5x336x9xi32>
+  %cst = arith.constant 0 : i32
+  %padded_0 = tensor.pad %collapse low[1, 0, 8, 0, 3] high[5, 0, 4, 0, 2] {
+  ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index):
+    tensor.yield %cst : i32
+  } : tensor<2x12x5x336x9xi32> to tensor<8x12x17x336x14xi32>
+  return %padded_0 : tensor<8x12x17x336x14xi32>
+}
+//      CHECK: func @fuse_by_expanding_pad(
+// CHECK-SAME:   %[[ARG0:.+]]: tensor<2x3x4x5x6x7x8x9xi32>)
+//      CHECK:   %[[PAD:.+]] = tensor.pad %[[ARG0]]
+// CHECK-SAME:       low[1, 0, 0, 8, 0, 0, 0, 3] high[5, 0, 0, 4, 0, 0, 0, 2]
+//      CHECK:       tensor<2x3x4x5x6x7x8x9xi32> to tensor<8x3x4x17x6x7x8x14xi32>
+//      CHECK:   %[[COLLAPSE:.+]] = tensor.collapse_shape %[[PAD]] {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]]
+// CHECK-SAME:       : tensor<8x3x4x17x6x7x8x14xi32> into tensor<8x12x17x336x14xi32>
+//      CHECK:   return %[[COLLAPSE]]
+
+// -----
+
+func.func @no_fuse_by_expanding_pad(%arg0 : tensor<2x3x4x5x6x7x8x9xi32>) -> tensor<8x12x17x339x14xi32> {
+  %collapse = tensor.collapse_shape %arg0 [[0], [1, 2], [3], [4, 5, 6], [7]] : tensor<2x3x4x5x6x7x8x9xi32> into tensor<2x12x5x336x9xi32>
+  %cst = arith.constant 0 : i32
+  %padded_0 = tensor.pad %collapse low[1, 0, 8, 0, 3] high[5, 0, 4, 3, 2] {
+  ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index):
+    tensor.yield %cst : i32
+  } : tensor<2x12x5x336x9xi32> to tensor<8x12x17x339x14xi32>
+  return %padded_0 : tensor<8x12x17x339x14xi32>
+}
+//      CHECK: func @no_fuse_by_expanding_pad(
+// CHECK-SAME:   %[[ARG0:.+]]: tensor<2x3x4x5x6x7x8x9xi32>)
+//      CHECK:   %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]]
+// CHECK-SAME:       : tensor<2x3x4x5x6x7x8x9xi32> into tensor<2x12x5x336x9xi32>
+//      CHECK:   %[[PAD:.+]] = tensor.pad %[[COLLAPSE]]
+// CHECK-SAME:       low[1, 0, 8, 0, 3] high[5, 0, 4, 3, 2]
+//      CHECK:       tensor<2x12x5x336x9xi32> to tensor<8x12x17x339x14xi32>
+//      CHECK:   return %[[PAD]]
+
+// -----
+
+func.func @fuse_by_expanding_dynamic_pad(%arg0 : tensor<?x?x?x?x?x?xi32>, %l0: index, %l1: index, %h0: index, %h1: index) -> tensor<?x?x?x?xi32> {
+  %collapse = tensor.collapse_shape %arg0 [[0], [1, 2], [3], [4, 5]] : tensor<?x?x?x?x?x?xi32> into tensor<?x?x?x?xi32>
+  %cst = arith.constant 0 : i32
+  %padded_0 = tensor.pad %collapse low[%l0, 0, %l1, 0] high[%h0, 0, %h1, 0] {
+  ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index):
+    tensor.yield %cst : i32
+  } : tensor<?x?x?x?xi32> to tensor<?x?x?x?xi32>
+  return %padded_0 : tensor<?x?x?x?xi32>
+}
+//      CHECK: func @fuse_by_expanding_dynamic_pad(
+// CHECK-SAME:   %[[ARG0:.+]]: tensor<?x?x?x?x?x?xi32>
+// CHECK-SAME:   %[[L0:.+]]: index, %[[L1:.+]]: index, %[[H0:.+]]: index, %[[H1:.+]]: index
+//      CHECK:   %[[PAD:.+]] = tensor.pad %[[ARG0]]
+// CHECK-SAME:       low[%[[L0]], 0, 0, %[[L1]], 0, 0] high[%[[H0]], 0, 0, %[[H1]], 0, 0]
+//      CHECK:       tensor<?x?x?x?x?x?xi32> to tensor<?x?x?x?x?x?xi32>
+//      CHECK:   %[[COLLAPSE:.+]] = tensor.collapse_shape %[[PAD]] {{\[}}[0], [1, 2], [3], [4, 5]]
+// CHECK-SAME:       : tensor<?x?x?x?x?x?xi32> into tensor<?x?x?x?xi32>
+//      CHECK:   return %[[COLLAPSE]]



More information about the Mlir-commits mailing list