[Mlir-commits] [mlir] [mlir][tosa] Canonicalise slice over overlapped or inside a pad. (PR #138270)

Luke Hutton llvmlistbot at llvm.org
Wed May 7 04:58:53 PDT 2025


================
@@ -731,6 +731,127 @@ struct ConcatSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
   }
 };
 
+struct PadSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
+  using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
+                                PatternRewriter &rewriter) const override {
+    Value sliceInput = sliceOp.getInput1();
+
+    // Check if producer is a PadOp
+    auto padOp = sliceInput.getDefiningOp<tosa::PadOp>();
+    if (!padOp)
+      return rewriter.notifyMatchFailure(sliceOp,
+                                         "slice input must be a pad operation");
+
+    // Check PadOp has a single consumer
+    if (!padOp->hasOneUse())
+      return rewriter.notifyMatchFailure(sliceOp,
+                                         "pad shall have a single consumer");
+
+    // Check input is statically ranked
+    auto inputTy = dyn_cast<RankedTensorType>(padOp.getInput1().getType());
+    auto padTy = dyn_cast<RankedTensorType>(padOp.getType());
+    if (!inputTy || !padTy)
+      return rewriter.notifyMatchFailure(
+          sliceOp, "slice input must be a static ranked tensor");
+
+    // Validate and extract tosa::PadOp padding
+    DenseIntElementsAttr paddingElems;
+    if (!matchPattern(padOp.getPadding(), m_Constant(&paddingElems))) {
+      return rewriter.notifyMatchFailure(
+          sliceOp,
+          "The `padding` input specified on the tosa::PadOp must be constant.");
+    }
+    llvm::SmallVector<int64_t> padPaddings =
+        llvm::to_vector(paddingElems.getValues<int64_t>());
+
+    // Extract slice parameters
+    DenseElementsAttr startElems;
+    if (!matchPattern(sliceOp.getStart(), m_Constant(&startElems)))
+      return rewriter.notifyMatchFailure(
+          sliceOp, "start of slice must be a static ranked shape");
+    llvm::SmallVector<int64_t> sliceStarts =
+        llvm::to_vector(startElems.getValues<int64_t>());
+
+    DenseElementsAttr sizeElems;
+    if (!matchPattern(sliceOp.getSize(), m_Constant(&sizeElems)))
+      return rewriter.notifyMatchFailure(
+          sliceOp, "size of slice must be a static ranked shape");
+    llvm::SmallVector<int64_t> sliceSizes =
+        llvm::to_vector(sizeElems.getValues<int64_t>());
+
+    // Update the paddings
+    int64_t rank = inputTy.getRank();
+    llvm::SmallVector<int64_t> newSliceStarts(rank, 0);
+    llvm::SmallVector<int64_t> newPadPaddings(2 * rank, 0);
+    llvm::SmallVector<int64_t> newPadShape(rank, 0);
+    bool updated = false;
+    for (int64_t i = 0; i < rank; ++i) {
+      const int64_t padLo = padPaddings[i * 2];
+      const int64_t padHi = padPaddings[i * 2 + 1];
----------------
lhutton1 wrote:

I missed the following use further below: https://github.com/llvm/llvm-project/pull/138270/files#diff-a09923d3e3c3e2c9cfa17a22fd05346bd8e3caa2c7332e64cba7f51f2748d916R798. Please ignore this comment.

https://github.com/llvm/llvm-project/pull/138270


More information about the Mlir-commits mailing list