[Mlir-commits] [mlir] [mlir] Add reshape propagation patterns for tensor.pad (PR #94489)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Jun 6 09:33:15 PDT 2024


================
@@ -1702,6 +1760,80 @@ class FoldWithProducerReshapeOpByCollapsing
   ControlFusionFn controlFoldingReshapes;
 };
 
+class FoldPadWithProducerReshapeOpByCollapsing
+    : public OpRewritePattern<tensor::PadOp> {
+public:
+  FoldPadWithProducerReshapeOpByCollapsing(MLIRContext *context,
+                                           ControlFusionFn foldReshapes,
+                                           PatternBenefit benefit = 1)
+      : OpRewritePattern<tensor::PadOp>(context, benefit),
+        controlFoldingReshapes(std::move(foldReshapes)) {}
+
+  LogicalResult matchAndRewrite(tensor::PadOp padOp,
+                                PatternRewriter &rewriter) const override {
+    tensor::ExpandShapeOp reshapeOp =
+        padOp.getSource().getDefiningOp<tensor::ExpandShapeOp>();
+    if (!reshapeOp)
+      return failure();
+    if (!reshapeOp->hasOneUse())
+      return failure();
+
+    ArrayRef<int64_t> low = padOp.getStaticLow();
+    ArrayRef<int64_t> high = padOp.getStaticHigh();
+    SmallVector<ReassociationIndices> reassociations =
+        reshapeOp.getReassociationIndices();
+
+    for (auto reInd : reassociations) {
+      if (reInd.size() == 1)
+        continue;
+      if (llvm::any_of(reInd, [&](int64_t ind) {
+            return low[ind] != 0 || high[ind] != 0;
+          })) {
+        return failure();
+      }
+    }
+
+    SmallVector<OpFoldResult> newLow, newHigh;
----------------
Max191 wrote:

They are different enough that I think it would be difficult to reuse much of the implementation. Both patterns propagate reshapes downward; they just operate on different reshapes.

https://github.com/llvm/llvm-project/pull/94489


More information about the Mlir-commits mailing list