[Mlir-commits] [mlir] [mlir] Add reshape propagation patterns for tensor.pad (PR #94489)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Jun 5 13:15:03 PDT 2024


================
@@ -1702,6 +1760,80 @@ class FoldWithProducerReshapeOpByCollapsing
   ControlFusionFn controlFoldingReshapes;
 };
 
+class FoldPadWithProducerReshapeOpByCollapsing
+    : public OpRewritePattern<tensor::PadOp> {
+public:
+  FoldPadWithProducerReshapeOpByCollapsing(MLIRContext *context,
+                                           ControlFusionFn foldReshapes,
+                                           PatternBenefit benefit = 1)
+      : OpRewritePattern<tensor::PadOp>(context, benefit),
+        controlFoldingReshapes(std::move(foldReshapes)) {}
+
+  LogicalResult matchAndRewrite(tensor::PadOp padOp,
+                                PatternRewriter &rewriter) const override {
+    tensor::ExpandShapeOp reshapeOp =
+        padOp.getSource().getDefiningOp<tensor::ExpandShapeOp>();
+    if (!reshapeOp)
+      return failure();
+    if (!reshapeOp->hasOneUse())
+      return failure();
+
+    ArrayRef<int64_t> low = padOp.getStaticLow();
+    ArrayRef<int64_t> high = padOp.getStaticHigh();
+    SmallVector<ReassociationIndices> reassociations =
+        reshapeOp.getReassociationIndices();
+
+    for (auto reInd : reassociations) {
+      if (reInd.size() == 1)
+        continue;
+      if (llvm::any_of(reInd, [&](int64_t ind) {
+            return low[ind] != 0 || high[ind] != 0;
+          })) {
+        return failure();
+      }
+    }
+
+    SmallVector<OpFoldResult> newLow, newHigh;
----------------
MaheshRavishankar wrote:

Could we reuse some of the existing implementation? Propagating an expand_shape up and propagating a collapse_shape down are essentially the same transformation.

https://github.com/llvm/llvm-project/pull/94489


More information about the Mlir-commits mailing list