[Mlir-commits] [mlir] [mlir][tensor] Add canonicalization to fold consecutive tensor.pad ops (PR #107302)

Quinn Dawkins llvmlistbot at llvm.org
Sat Sep 7 09:21:35 PDT 2024


https://github.com/qedawkins updated https://github.com/llvm/llvm-project/pull/107302

>From 4fd39cc967697370f6cee2766cb93c6bc9957dd6 Mon Sep 17 00:00:00 2001
From: Quinn Dawkins <quinn at nod-labs.com>
Date: Wed, 4 Sep 2024 16:08:05 -0400
Subject: [PATCH 1/2] [mlir][tensor] Add canonicalization to fold consecutive
 tensor.pad ops

`tensor.pad(tensor.pad)` with the same constant padding value can be
combined into a single pad that pads to the sum of the high and low
padding amounts.
---
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp   | 85 +++++++++++++++++++++-
 mlir/test/Dialect/Tensor/canonicalize.mlir | 82 +++++++++++++++++++++
 2 files changed, 166 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 996de530c255d4..48f9aa0d0664a1 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -3397,12 +3397,95 @@ struct FoldStaticPadding : public OpRewritePattern<PadOp> {
   }
 };
 
+/// Folds a chain of `tensor.pad` ops with the same constant padding value.
+///
+/// Example:
+///
+/// ```mlir
+///   %1 = tensor.pad %0 low[0, 1] high[0, 2] {
+///       tensor.yield %val
+///     } : tensor<1x2xf32> to tensor<1x5xf32>
+///   %res = tensor.pad %1 low[0, 2] high[3, 0] {
+///       tensor.yield %val
+///     } : tensor<1x5xf32> to tensor<4x7xf32>
+/// ```
+///
+/// folds into:
+///
+/// ```mlir
+///   %res = tensor.pad %0 low[0, 3] high[3, 2] {
+///       tensor.yield %val
+///     } : tensor<1x2xf32> to tensor<4x7xf32>
+/// ```
+struct FoldConsecutiveConstantPadding : public OpRewritePattern<tensor::PadOp> {
+  using OpRewritePattern<tensor::PadOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::PadOp padOp,
+                                PatternRewriter &rewriter) const override {
+    auto producerPad = padOp.getSource().getDefiningOp<tensor::PadOp>();
+    if (!producerPad || producerPad.getNofold()) {
+      return rewriter.notifyMatchFailure(
+          padOp, "producer is not a foldable tensor.pad op");
+    }
+
+    // Fail unless both pads yield the same constant padding value.
+    Value consumerPadValue = padOp.getConstantPaddingValue();
+    Value producerPadValue = producerPad.getConstantPaddingValue();
+    if (!consumerPadValue || !producerPadValue ||
+        consumerPadValue != producerPadValue) {
+      return rewriter.notifyMatchFailure(
+          padOp,
+          "cannot fold PadOps with different or non-constant padding values");
+    }
+
+    Location loc = padOp.getLoc();
+
+    // Sum the paddings pairwise, keeping constant sums as index attributes.
+    auto addPaddings = [&](ArrayRef<OpFoldResult> consumerPaddings,
+                           ArrayRef<OpFoldResult> producerPaddings) {
+      SmallVector<OpFoldResult> sumPaddings;
+      for (auto [consumerIndex, producerIndex] :
+           llvm::zip_equal(consumerPaddings, producerPaddings)) {
+        Value consumerIndexVal =
+            getValueOrCreateConstantIndexOp(rewriter, loc, consumerIndex);
+        Value producerIndexVal =
+            getValueOrCreateConstantIndexOp(rewriter, loc, producerIndex);
+        Value sum = rewriter.createOrFold<arith::AddIOp>(loc, consumerIndexVal,
+                                                         producerIndexVal);
+        APInt constantSum;
+        if (matchPattern(sum, m_ConstantInt(&constantSum))) {
+          sumPaddings.push_back(
+              rewriter.getIndexAttr(constantSum.getSExtValue()));
+        } else {
+          sumPaddings.push_back(sum);
+        }
+      }
+      return sumPaddings;
+    };
+
+    SmallVector<OpFoldResult> newHighPad =
+        addPaddings(padOp.getMixedHighPad(), producerPad.getMixedHighPad());
+    SmallVector<OpFoldResult> newLowPad =
+        addPaddings(padOp.getMixedLowPad(), producerPad.getMixedLowPad());
+
+    auto newPadOp = rewriter.create<tensor::PadOp>(
+        padOp.getLoc(), padOp.getResultType(), producerPad.getSource(),
+        newLowPad, newHighPad, padOp.getNofold(),
+        getPrunedAttributeList(padOp, tensor::PadOp::getAttributeNames()));
+    rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(),
+                                newPadOp.getRegion().begin());
+    rewriter.replaceOp(padOp, newPadOp.getResult());
+    return success();
+  }
+};
+
 } // namespace
 
 void PadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
   results.add<FoldStaticZeroPadding, FoldSourceTensorCast, FoldTargetTensorCast,
-              FoldOrthogonalPaddings, FoldStaticPadding>(context);
+              FoldOrthogonalPaddings, FoldStaticPadding,
+              FoldConsecutiveConstantPadding>(context);
 }
 
 /// Return the padding value of the PadOp if it constant. In this context,
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 735790e5bd6c5e..fb9e6c96b09ef6 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1964,6 +1964,88 @@ func.func @dont_fold_pad_chains(%arg0: tensor<64x64xf32>,
 
 // -----
 
+// CHECK-LABEL: func @merge_constant_padding
+//  CHECK-SAME:   %[[ARG0:[A-Za-z0-9]+]]: tensor<2x3xf32>
+//  CHECK-SAME:   %[[PADVAL:[A-Za-z0-9]+]]: f32
+//       CHECK:   %[[PAD:.+]] = tensor.pad %[[ARG0]] low[1, 3] high[4, 2]
+//       CHECK:     tensor.yield %[[PADVAL]]
+//       CHECK:   return %[[PAD]]
+func.func @merge_constant_padding(%arg0: tensor<2x3xf32>, %pad_value: f32) -> tensor<7x8xf32> {
+  %pad0 = tensor.pad %arg0 low[1, 1] high[1, 0] {
+    ^bb0(%b0: index, %b1 : index):
+      tensor.yield %pad_value : f32
+    } : tensor<2x3xf32> to tensor<4x4xf32>
+  %pad1 = tensor.pad %pad0 low[0, 2] high[3, 2] {
+    ^bb0(%b2: index, %b3 : index):
+      tensor.yield %pad_value : f32
+    } : tensor<4x4xf32> to tensor<7x8xf32>
+  return %pad1 : tensor<7x8xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @merge_constant_padding_dynamic
+//  CHECK-SAME:   %[[ARG0:[A-Za-z0-9]+]]: tensor<?x?xf32>
+//  CHECK-SAME:   %[[IDX:[A-Za-z0-9]+]]: index
+//  CHECK-SAME:   %[[PADVAL:[A-Za-z0-9]+]]: f32
+//       CHECK:   %[[C1:.+]] = arith.constant 1 : index
+//       CHECK:   %[[HIGH:.+]] = arith.addi %[[IDX]], %[[C1]] : index
+//       CHECK:   %[[PAD:.+]] = tensor.pad %[[ARG0]] low[%[[IDX]], 3] high[%[[HIGH]], 2]
+//       CHECK:     tensor.yield %[[PADVAL]]
+//       CHECK:   return %[[PAD]]
+func.func @merge_constant_padding_dynamic(%arg0: tensor<?x?xf32>, %idx: index, %pad_value: f32) -> tensor<?x?xf32> {
+  %pad0 = tensor.pad %arg0 low[%idx, 1] high[1, 0] {
+    ^bb0(%b0: index, %b1 : index):
+      tensor.yield %pad_value : f32
+    } : tensor<?x?xf32> to tensor<?x?xf32>
+  %pad1 = tensor.pad %pad0 low[0, 2] high[%idx, 2] {
+    ^bb0(%b2: index, %b3 : index):
+      tensor.yield %pad_value : f32
+    } : tensor<?x?xf32> to tensor<?x?xf32>
+  return %pad1 : tensor<?x?xf32>
+}
+
+// -----
+
+// Verify that folding does not happen if it would drop a nofold attribute
+// CHECK-LABEL: func @dont_merge_constant_padding_nofold
+//       CHECK:   tensor.pad {{.*}} nofold
+//       CHECK:   tensor.pad
+func.func @dont_merge_constant_padding_nofold(%arg0: tensor<2x3xf32>, %pad_value: f32) -> tensor<7x8xf32> {
+  %pad0 = tensor.pad %arg0 nofold low[1, 1] high[1, 0] {
+    ^bb0(%b0: index, %b1 : index):
+      tensor.yield %pad_value : f32
+    } : tensor<2x3xf32> to tensor<4x4xf32>
+  %pad1 = tensor.pad %pad0 low[0, 2] high[3, 2] {
+    ^bb0(%b2: index, %b3 : index):
+      tensor.yield %pad_value : f32
+    } : tensor<4x4xf32> to tensor<7x8xf32>
+  return %pad1 : tensor<7x8xf32>
+}
+
+// -----
+
+// Verify that folding does not happen if the padding values are different
+// CHECK-LABEL: func @dont_merge_constant_padding_different_vals
+//       CHECK:   tensor.pad
+//       CHECK:   tensor.pad
+func.func @dont_merge_constant_padding_different_vals(
+    %arg0: tensor<2x3xf32>,
+    %pad_value0: f32,
+    %pad_value1: f32) -> tensor<7x8xf32> {
+  %pad0 = tensor.pad %arg0 low[1, 1] high[1, 0] {
+    ^bb0(%b0: index, %b1 : index):
+      tensor.yield %pad_value0 : f32
+    } : tensor<2x3xf32> to tensor<4x4xf32>
+  %pad1 = tensor.pad %pad0 low[0, 2] high[3, 2] {
+    ^bb0(%b2: index, %b3 : index):
+      tensor.yield %pad_value1 : f32
+    } : tensor<4x4xf32> to tensor<7x8xf32>
+  return %pad1 : tensor<7x8xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @fold_collapse_shape_from_elements
 func.func @fold_collapse_shape_from_elements(%arg0: i32) -> tensor<i32> {
   // CHECK: %[[FROM:.+]] = tensor.from_elements %arg0 : tensor<i32>

>From 87d638f2247dc0b725ecd3f24cc957b7fa204a03 Mon Sep 17 00:00:00 2001
From: Quinn Dawkins <quinn at nod-labs.com>
Date: Sat, 7 Sep 2024 12:21:22 -0400
Subject: [PATCH 2/2] add back nofold check

---
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 48f9aa0d0664a1..16970d8c304653 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -3422,6 +3422,10 @@ struct FoldConsecutiveConstantPadding : public OpRewritePattern<tensor::PadOp> {
 
   LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                 PatternRewriter &rewriter) const override {
+    if (padOp.getNofold()) {
+      return rewriter.notifyMatchFailure(padOp, "skipping unfoldable pad");
+    }
+
     auto producerPad = padOp.getSource().getDefiningOp<tensor::PadOp>();
     if (!producerPad || producerPad.getNofold()) {
       return rewriter.notifyMatchFailure(



More information about the Mlir-commits mailing list