[Mlir-commits] [mlir] [mlir][tensor] Add canonicalization to fold consecutive tensor.pad ops (PR #107302)
Quinn Dawkins
llvmlistbot at llvm.org
Thu Sep 5 11:29:33 PDT 2024
================
@@ -3397,12 +3397,95 @@ struct FoldStaticPadding : public OpRewritePattern<PadOp> {
}
};
+/// Folds a chain of `tensor.pad` ops with the same constant padding value.
+///
+/// Example:
+///
+/// ```mlir
+/// %1 = tensor.pad %0 low[0, 1] high[0, 2] {
+/// tensor.yield %val
+/// } : tensor<1x2xf32> to tensor<1x5xf32>
+/// %res = tensor.pad %1 low[0, 2] high[3, 0] {
+/// tensor.yield %val
+/// } : tensor<1x5xf32> to tensor<4x7xf32>
+/// ```
+///
+/// folds into:
+///
+/// ```mlir
+/// %res = tensor.pad %0 low[0, 3] high[3, 2] {
+/// tensor.yield %val
+/// } : tensor<1x2xf32> to tensor<4x7xf32>
+/// ```
+struct FoldConsecutiveConstantPadding : public OpRewritePattern<tensor::PadOp> {
+ using OpRewritePattern<tensor::PadOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(tensor::PadOp padOp,
+ PatternRewriter &rewriter) const override {
+ auto producerPad = padOp.getSource().getDefiningOp<tensor::PadOp>();
+ if (!producerPad || producerPad.getNofold()) {
+ return rewriter.notifyMatchFailure(
+ padOp, "producer is not a foldable tensor.pad op");
+ }
+
+ // Fail if the tensor::PadOps padding values do not match.
+ Value consumerPadValue = padOp.getConstantPaddingValue();
+ Value producerPadValue = producerPad.getConstantPaddingValue();
+ if (!consumerPadValue || !producerPadValue ||
+ consumerPadValue != producerPadValue) {
+ return rewriter.notifyMatchFailure(
+ padOp,
+ "cannot fold PadOps with different or non-constant padding values");
+ }
+
+ Location loc = padOp.getLoc();
+
+ // Combine the low/high paddings of the two tensor::PadOps.
+ auto addPaddings = [&](ArrayRef<OpFoldResult> consumerPaddings,
+ ArrayRef<OpFoldResult> producerPaddings) {
+ SmallVector<OpFoldResult> sumPaddings;
+ for (auto [consumerIndex, producerIndex] :
+ llvm::zip_equal(consumerPaddings, producerPaddings)) {
+ Value consumerIndexVal =
+ getValueOrCreateConstantIndexOp(rewriter, loc, consumerIndex);
+ Value producerIndexVal =
+ getValueOrCreateConstantIndexOp(rewriter, loc, producerIndex);
+ Value sum = rewriter.createOrFold<arith::AddIOp>(loc, consumerIndexVal,
+ producerIndexVal);
+ APInt constantSum;
+ if (matchPattern(sum, m_ConstantInt(&constantSum))) {
----------------
qedawkins wrote:
I avoided that because I thought the Tensor dialect couldn't take a dependency on Affine (and it would also mean that the CanonicalizerPass would need to register the Affine dialect, IIUC), but if that dependency is fine, I'd be happy to use that helper.
https://github.com/llvm/llvm-project/pull/107302
More information about the Mlir-commits
mailing list