[Mlir-commits] [mlir] [mlir][tensor] Add canonicalization to fold consecutive tensor.pad ops (PR #107302)
Quinn Dawkins
llvmlistbot at llvm.org
Thu Sep 5 14:04:51 PDT 2024
================
@@ -3397,12 +3397,95 @@ struct FoldStaticPadding : public OpRewritePattern<PadOp> {
}
};
+/// Folds a chain of `tensor.pad` ops with the same constant padding value.
+///
+/// Example:
+///
+/// ```mlir
+/// %1 = tensor.pad %0 low[0, 1] high[0, 2] {
+/// tensor.yield %val
+/// } : tensor<1x2xf32> to tensor<1x5xf32>
+/// %res = tensor.pad %1 low[0, 2] high[3, 0] {
+/// tensor.yield %val
+/// } : tensor<1x5xf32> to tensor<4x7xf32>
+/// ```
+///
+/// folds into:
+///
+/// ```mlir
+/// %res = tensor.pad %0 low[0, 3] high[3, 2] {
+/// tensor.yield %val
+/// } : tensor<1x2xf32> to tensor<4x7xf32>
+/// ```
+struct FoldConsecutiveConstantPadding : public OpRewritePattern<tensor::PadOp> {
+ using OpRewritePattern<tensor::PadOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(tensor::PadOp padOp,
+ PatternRewriter &rewriter) const override {
+ auto producerPad = padOp.getSource().getDefiningOp<tensor::PadOp>();
+ if (!producerPad || producerPad.getNofold()) {
+ return rewriter.notifyMatchFailure(
+ padOp, "producer is not a foldable tensor.pad op");
+ }
----------------
qedawkins wrote:
I initially had that, but there is another, similar canonicalizer that only checks that the producer pad is foldable. If `padOp` itself is not foldable, that's fine: we're folding the producer pad into `padOp`, so `padOp` is not the op getting folded away. I think `nofold` is a dubious attribute anyway, so someone who has a stake in how `nofold` works is welcome to change it to the behavior they want.
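
To make the check concrete, here is a minimal standalone sketch of the folding rule. It does not use the MLIR API: `PadModel`, `paddedShape`, and `foldConsecutivePads` are hypothetical names for illustration, padding is assumed to be static, and carrying the consumer's `nofold` flag over to the merged pad is an assumption rather than something shown in the hunk above.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for a tensor.pad op with static padding amounts.
// This is an illustration only, not the MLIR tensor::PadOp API.
struct PadModel {
  std::vector<int64_t> low;   // per-dimension low padding
  std::vector<int64_t> high;  // per-dimension high padding
  float value;                // constant padding value
  bool nofold;                // models the `nofold` attribute
};

// Shape after applying a pad: every dimension grows by low + high.
std::vector<int64_t> paddedShape(const std::vector<int64_t> &shape,
                                 const PadModel &pad) {
  std::vector<int64_t> result(shape.size());
  for (std::size_t i = 0; i < shape.size(); ++i)
    result[i] = shape[i] + pad.low[i] + pad.high[i];
  return result;
}

// Tries to fold `producer` into `consumer`. Returns false when the fold
// does not apply. Only the producer's nofold flag is checked: the producer
// is the op that disappears, while the consumer survives in merged form.
bool foldConsecutivePads(const PadModel &producer, const PadModel &consumer,
                         PadModel &merged) {
  if (producer.nofold)
    return false;  // the producer must not be folded away
  if (producer.value != consumer.value)
    return false;  // both pads must use the same constant padding value
  if (producer.low.size() != consumer.low.size())
    return false;  // ranks must agree

  merged.low.resize(producer.low.size());
  merged.high.resize(producer.high.size());
  for (std::size_t i = 0; i < producer.low.size(); ++i) {
    // Consecutive pads compose by adding their amounts per dimension.
    merged.low[i] = producer.low[i] + consumer.low[i];
    merged.high[i] = producer.high[i] + consumer.high[i];
  }
  merged.value = consumer.value;
  merged.nofold = consumer.nofold;  // assumption: the consumer's flag is kept
  return true;
}
```

With the two pads from the doc-comment example (low[0, 1] high[0, 2] followed by low[0, 2] high[3, 0]), this model yields low[0, 3] high[3, 2], and it still folds when only the consumer is marked `nofold`.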
https://github.com/llvm/llvm-project/pull/107302