[Mlir-commits] [mlir] [mlir][tensor] Add canonicalization to fold consecutive tensor.pad ops (PR #107302)

Quinn Dawkins llvmlistbot at llvm.org
Mon Sep 9 07:55:42 PDT 2024


================
@@ -3397,12 +3397,95 @@ struct FoldStaticPadding : public OpRewritePattern<PadOp> {
   }
 };
 
+/// Folds a chain of `tensor.pad` ops with the same constant padding value.
+///
+/// Example:
+///
+/// ```mlir
+///   %1 = tensor.pad %0 low[0, 1] high[0, 2] {
+///       tensor.yield %val
+///     } : tensor<1x2xf32> to tensor<1x5xf32>
+///   %res = tensor.pad %1 low[0, 2] high[3, 0] {
+///       tensor.yield %val
+///     } : tensor<1x5xf32> to tensor<4x7xf32>
+/// ```
+///
+/// folds into:
+///
+/// ```mlir
+///   %res = tensor.pad %0 low[0, 3] high[3, 2] {
+///       tensor.yield %val
+///     } : tensor<1x2xf32> to tensor<4x7xf32>
+/// ```
+struct FoldConsecutiveConstantPadding : public OpRewritePattern<tensor::PadOp> {
+  using OpRewritePattern<tensor::PadOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::PadOp padOp,
+                                PatternRewriter &rewriter) const override {
+    auto producerPad = padOp.getSource().getDefiningOp<tensor::PadOp>();
+    if (!producerPad || producerPad.getNofold()) {
+      return rewriter.notifyMatchFailure(
+          padOp, "producer is not a foldable tensor.pad op");
+    }
+
+    // Fail if the tensor::PadOps padding values do not match.
+    Value consumerPadValue = padOp.getConstantPaddingValue();
+    Value producerPadValue = producerPad.getConstantPaddingValue();
+    if (!consumerPadValue || !producerPadValue ||
+        consumerPadValue != producerPadValue) {
+      return rewriter.notifyMatchFailure(
+          padOp,
+          "cannot fold PadOps with different or non-constant padding values");
+    }
+
+    Location loc = padOp.getLoc();
+
+    // Combine the low/high paddings of the two tensor::PadOps.
+    auto addPaddings = [&](ArrayRef<OpFoldResult> consumerPaddings,
+                           ArrayRef<OpFoldResult> producerPaddings) {
+      SmallVector<OpFoldResult> sumPaddings;
+      for (auto [consumerIndex, producerIndex] :
+           llvm::zip_equal(consumerPaddings, producerPaddings)) {
+        Value consumerIndexVal =
+            getValueOrCreateConstantIndexOp(rewriter, loc, consumerIndex);
+        Value producerIndexVal =
+            getValueOrCreateConstantIndexOp(rewriter, loc, producerIndex);
+        Value sum = rewriter.createOrFold<arith::AddIOp>(loc, consumerIndexVal,
+                                                         producerIndexVal);
+        APInt constantSum;
+        if (matchPattern(sum, m_ConstantInt(&constantSum))) {
----------------
qedawkins wrote:

OK, I checked other canonicalizer examples in this file, and it looks like Affine is already used, so I switched to `makeComposedFoldedAffineApply` as you suggested.

https://github.com/llvm/llvm-project/pull/107302


More information about the Mlir-commits mailing list