[Mlir-commits] [mlir] [mlir] Add reshape propagation patterns for tensor.pad (PR #94489)
llvmlistbot at llvm.org
Thu Jun 6 09:33:15 PDT 2024
================
@@ -1702,6 +1760,80 @@ class FoldWithProducerReshapeOpByCollapsing
   ControlFusionFn controlFoldingReshapes;
 };
+class FoldPadWithProducerReshapeOpByCollapsing
+    : public OpRewritePattern<tensor::PadOp> {
+public:
+  FoldPadWithProducerReshapeOpByCollapsing(MLIRContext *context,
+                                           ControlFusionFn foldReshapes,
+                                           PatternBenefit benefit = 1)
+      : OpRewritePattern<tensor::PadOp>(context, benefit),
+        controlFoldingReshapes(std::move(foldReshapes)) {}
+
+  LogicalResult matchAndRewrite(tensor::PadOp padOp,
+                                PatternRewriter &rewriter) const override {
+    tensor::ExpandShapeOp reshapeOp =
+        padOp.getSource().getDefiningOp<tensor::ExpandShapeOp>();
+    if (!reshapeOp)
+      return failure();
+    if (!reshapeOp->hasOneUse())
+      return failure();
+
+    ArrayRef<int64_t> low = padOp.getStaticLow();
+    ArrayRef<int64_t> high = padOp.getStaticHigh();
+    SmallVector<ReassociationIndices> reassociations =
+        reshapeOp.getReassociationIndices();
+
+    for (auto reInd : reassociations) {
+      if (reInd.size() == 1)
+        continue;
+      if (llvm::any_of(reInd, [&](int64_t ind) {
+            return low[ind] != 0 || high[ind] != 0;
+          })) {
+        return failure();
+      }
+    }
+
+    SmallVector<OpFoldResult> newLow, newHigh;
----------------
Max191 wrote:
They are different enough that I think it would be difficult to reuse much of the implementation. Both propagate reshapes down; they just propagate different reshapes.
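A standalone sketch of the legality check and padding bookkeeping in the quoted hunk may help when comparing the patterns. It uses plain `std::vector` in place of MLIR types; `Reassociation`, `CollapsedPadding`, and `collapsePadding` are hypothetical names, not part of MLIR or this PR. The message cuts off right after `newLow, newHigh` are declared, so the way they are filled here (one entry per reassociation group, zero for multi-dimension groups) is an assumption about how the pad is moved above the `tensor.expand_shape`.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical model: one group per collapsed (source) dimension, listing the
// expanded (result) dimensions that the expand_shape splits it into.
using Reassociation = std::vector<std::vector<int64_t>>;

struct CollapsedPadding {
  std::vector<int64_t> low;   // one entry per collapsed dimension
  std::vector<int64_t> high;
};

// Given static low/high padding on the expanded tensor, return the padding to
// apply to the collapsed source instead, or std::nullopt when some dimension
// inside a multi-dimension group is padded (the fold is then illegal).
std::optional<CollapsedPadding>
collapsePadding(const Reassociation &reassociation,
                const std::vector<int64_t> &low,
                const std::vector<int64_t> &high) {
  assert(low.size() == high.size() && "low/high rank mismatch");
  CollapsedPadding result;
  for (const std::vector<int64_t> &group : reassociation) {
    if (group.size() == 1) {
      // A dimension the reshape leaves untouched keeps its padding as-is.
      result.low.push_back(low[group[0]]);
      result.high.push_back(high[group[0]]);
      continue;
    }
    // Same condition as the any_of check in the hunk: padding any dimension
    // produced by splitting a source dimension blocks the fold.
    for (int64_t dim : group)
      if (low[dim] != 0 || high[dim] != 0)
        return std::nullopt;
    result.low.push_back(0);
    result.high.push_back(0);
  }
  return result;
}
```

For example, with groups `{{0, 1}, {2}}` (a `tensor<6x4xf32>` expanded to `tensor<2x3x4xf32>`) and low = high = `{0, 0, 1}`, the sketch returns low = high = `{0, 1}`: padding the source to `6x6` and then expanding gives the same `2x3x6` result. Padding dimension 1 instead returns `std::nullopt`, because that dimension comes from splitting source dimension 0.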
https://github.com/llvm/llvm-project/pull/94489