[Mlir-commits] [mlir] [mlir][tosa] Canonicalise slice over overlapped or inside a pad. (PR #138270)

Luke Hutton llvmlistbot at llvm.org
Tue May 6 05:19:52 PDT 2025


================
@@ -731,6 +731,127 @@ struct ConcatSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
   }
 };
 
+struct PadSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
+  using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
+                                PatternRewriter &rewriter) const override {
+    Value sliceInput = sliceOp.getInput1();
+
+    // Check if producer is a PadOp
+    auto padOp = sliceInput.getDefiningOp<tosa::PadOp>();
+    if (!padOp)
+      return rewriter.notifyMatchFailure(sliceOp,
+                                         "slice input must be a pad operation");
+
+    // Check PadOp has a single consumer
+    if (!padOp->hasOneUse())
+      return rewriter.notifyMatchFailure(sliceOp,
+                                         "pad shall have a single consumer");
+
+    // Check input is statically ranked
+    auto inputTy = dyn_cast<RankedTensorType>(padOp.getInput1().getType());
+    auto padTy = dyn_cast<RankedTensorType>(padOp.getType());
+    if (!inputTy || !padTy)
+      return rewriter.notifyMatchFailure(
+          sliceOp, "slice input must be a static ranked tensor");
+
+    // Validate and extract tosa::PadOp padding
+    DenseIntElementsAttr paddingElems;
+    if (!matchPattern(padOp.getPadding(), m_Constant(&paddingElems))) {
+      return rewriter.notifyMatchFailure(
+          sliceOp,
+          "The `padding` input specified on the tosa::PadOp must be constant.");
+    }
+    llvm::SmallVector<int64_t> padPaddings =
+        llvm::to_vector(paddingElems.getValues<int64_t>());
+
+    // Extract slice parameters
+    DenseElementsAttr startElems;
+    if (!matchPattern(sliceOp.getStart(), m_Constant(&startElems)))
+      return rewriter.notifyMatchFailure(
+          sliceOp, "start of slice must be a static ranked shape");
+    llvm::SmallVector<int64_t> sliceStarts =
+        llvm::to_vector(startElems.getValues<int64_t>());
+
+    DenseElementsAttr sizeElems;
+    if (!matchPattern(sliceOp.getSize(), m_Constant(&sizeElems)))
+      return rewriter.notifyMatchFailure(
+          sliceOp, "size of slice must be a static ranked shape");
+    llvm::SmallVector<int64_t> sliceSizes =
+        llvm::to_vector(sizeElems.getValues<int64_t>());
+
+    // Update the paddings
+    int64_t rank = inputTy.getRank();
+    llvm::SmallVector<int64_t> newSliceStarts(rank, 0);
+    llvm::SmallVector<int64_t> newPadPaddings(2 * rank, 0);
+    llvm::SmallVector<int64_t> newPadShape(rank, 0);
+    bool updated = false;
+    for (int64_t i = 0; i < rank; ++i) {
+      const int64_t padLo = padPaddings[i * 2];
+      const int64_t padHi = padPaddings[i * 2 + 1];
+      const int64_t sliceStart = sliceStarts[i];
+      const int64_t sliceSize = sliceSizes[i];
+      const int64_t sliceEnd = sliceStart + sliceSize;
+
+      const int64_t dimSize = inputTy.getShape()[i];
+      const int64_t dimStart = padLo;
+      const int64_t dimEnd = padLo + dimSize;
+      const int64_t dimTotal = padLo + dimSize + padHi;
+
+      // Check slice within bounds
+      if (sliceStart < 0 || sliceEnd > dimTotal)
+        return rewriter.notifyMatchFailure(sliceOp, "slice out-of-bounds");
+
+      const int64_t newPadLo = std::max<int64_t>(padLo - sliceStart, 0);
+      const int64_t newPadHi =
+          std::max<int64_t>(sliceEnd - (padLo + dimSize), 0);
+      const int64_t newSliceStart = std::max<int64_t>(sliceStart - padLo, 0);
+
+      // Compute update slice/pad parameters
+      if (sliceStart < dimStart || sliceEnd > dimEnd) {
+        // Handle slice when not within the original input entirely
+        updated |= (newPadLo != padLo) || (newPadHi != padHi) ||
+                   (newSliceStart != sliceStart);
+        newPadPaddings[i * 2] = newPadLo;
+        newPadPaddings[i * 2 + 1] = newPadHi;
+        newSliceStarts[i] = newSliceStart;
+      } else {
+        // Slice is within the original input
+        updated |= newSliceStart != sliceStart;
----------------
lhutton1 wrote:

I think we could remove this `else` branch entirely if this `updated` bookkeeping were also moved out of the conditional.

https://github.com/llvm/llvm-project/pull/138270


More information about the Mlir-commits mailing list