[Mlir-commits] [mlir] [mlir][tensor] Make tensor::PadOp a ReifyRankedShapedTypeOpInterface and add a PadOp::FoldReifiedShape canonicalization (PR #145732)
Nicolas Vasilache
llvmlistbot at llvm.org
Fri Jun 27 00:06:43 PDT 2025
https://github.com/nicolasvasilache updated https://github.com/llvm/llvm-project/pull/145732
>From eb336546d915d54db9115876a61e07c595a78637 Mon Sep 17 00:00:00 2001
From: Nicolas Vasilache <nico.vasilache at amd.com>
Date: Wed, 25 Jun 2025 16:46:44 +0200
Subject: [PATCH] [mlir][tensor] Add a PadOp::FoldReifiedShape canonicalization
---
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 43 +++++++++++++++++++++-
mlir/test/Dialect/Tensor/canonicalize.mlir | 18 +++++++++
2 files changed, 60 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 22a25fd1a5af8..c3147e297e2ff 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -3791,6 +3791,47 @@ struct FoldConsecutiveConstantPadding : public OpRewritePattern<tensor::PadOp> {
}
};
+/// Refine the result type of a `tensor.pad` using `reifyResultShapes`: every
+/// result dimension whose reified extent is a constant becomes static. The
+/// refined pad is followed by a `tensor.cast` back to the original result type
+/// so that existing users keep seeing the same type.
+struct FoldReifiedShape : public OpRewritePattern<tensor::PadOp> {
+  using OpRewritePattern<tensor::PadOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::PadOp padOp,
+                                PatternRewriter &rewriter) const override {
+    if (padOp.getNofold())
+      return rewriter.notifyMatchFailure(padOp, "skipping unfoldable pad");
+
+    RankedTensorType oldType = padOp.getResultType();
+    // Nothing can be refined on a fully static result; bail out before
+    // reifying so that this (common) failure path creates no IR.
+    if (oldType.hasStaticShape())
+      return rewriter.notifyMatchFailure(padOp, "result type already static");
+
+    // NOTE(review): reifyResultShapes may materialize ops through `rewriter`;
+    // the failure paths below would then leave those ops behind as dead IR.
+    ReifiedRankedShapedTypeDims reifiedResultShapes;
+    if (failed(reifyResultShapes(rewriter, padOp, reifiedResultShapes)))
+      return rewriter.notifyMatchFailure(padOp, "failed to reify shape");
+
+    SmallVector<int64_t> newShape;
+    newShape.reserve(oldType.getRank());
+    for (const auto &[s, ofr] :
+         llvm::zip_equal(oldType.getShape(), reifiedResultShapes.front())) {
+      std::optional<int64_t> maybeCst = getConstantIntValue(ofr);
+      // Reification does not add static information, just use existing shape.
+      if (!maybeCst.has_value()) {
+        newShape.push_back(s);
+        continue;
+      }
+      // A static extent that disagrees with the reified constant can come from
+      // verifier-valid IR (e.g. an SSA constant high padding); refuse to build
+      // an invalid cast instead of asserting.
+      if (!ShapedType::isDynamic(s) && s != *maybeCst)
+        return rewriter.notifyMatchFailure(padOp, "reified shape mismatch");
+      newShape.push_back(*maybeCst);
+    }
+    if (ArrayRef<int64_t>(newShape) == oldType.getShape())
+      return rewriter.notifyMatchFailure(padOp, "no static information gained");
+
+    Type newType = RankedTensorType::Builder(oldType).setShape(newShape);
+    Operation *newPad = rewriter.clone(*padOp);
+    // Change the clone's result type through the rewriter so that listeners
+    // (e.g. the greedy driver) are notified of the modification.
+    rewriter.modifyOpInPlace(newPad,
+                             [&] { newPad->getResult(0).setType(newType); });
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(padOp, oldType,
+                                                newPad->getResult(0));
+    return success();
+  }
+};
+
} // namespace
LogicalResult
@@ -3820,7 +3861,7 @@ void PadOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
+  // FoldReifiedShape (defined above) makes result dimensions that reify to
+  // constants static, then casts the refined pad back to the original type.
results.add<FoldStaticZeroPadding, FoldSourceTensorCast, FoldTargetTensorCast,
FoldOrthogonalPaddings, FoldStaticPadding,
- FoldConsecutiveConstantPadding>(context);
+ FoldConsecutiveConstantPadding, FoldReifiedShape>(context);
}
/// Return the padding value of the PadOp if it constant. In this context,
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 3f9236095138b..2a42a9a810ec4 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -2561,3 +2561,21 @@ func.func @partial_sink_expand_of_cast(%arg0 : tensor<10x10xf32>, %arg1 : index,
// CHECK: %[[RES:.+]] = tensor.cast %[[EXPAND]]
// CHECK-SAME: tensor<?x?x10xf32> to tensor<?x?x?xf32>
// CHECK: return %[[RES]]
+
+// -----
+
+// The padded extent reifies to %idx + (256 - %idx), which folds to 256, so the
+// pad result becomes static and is cast back to the original dynamic type.
+// CHECK-LABEL: func.func @pad_reification
+func.func @pad_reification(%cst : f32, %idx : index, %t: tensor<64x?x64xf32>)
+    -> tensor<1x?x64xf32> {
+  %pad_amt = affine.apply affine_map<(d0) -> (-d0 + 256)>(%idx)
+  %es = tensor.extract_slice %t[0, 0, 0] [1, %idx, 64] [1, 1, 1] : tensor<64x?x64xf32> to tensor<1x?x64xf32>
+
+//      CHECK: %[[PADDED:.+]] = tensor.pad
+//      CHECK:   : tensor<1x?x64xf32> to tensor<1x256x64xf32>
+//      CHECK: %[[CAST:.+]] = tensor.cast %[[PADDED]] : tensor<1x256x64xf32> to tensor<1x?x64xf32>
+//      CHECK: return %[[CAST]] : tensor<1x?x64xf32>
+  %padded = tensor.pad %es low[0, 0, 0] high[0, %pad_amt, 0] {
+  ^bb0(%a: index, %b: index, %c: index):
+    tensor.yield %cst : f32
+  } : tensor<1x?x64xf32> to tensor<1x?x64xf32>
+
+  return %padded : tensor<1x?x64xf32>
+}
More information about the Mlir-commits
mailing list