[Mlir-commits] [mlir] [mlir][tosa] Canonicalise slice over overlapped or inside a pad. (PR #138270)
Luke Hutton
llvmlistbot at llvm.org
Tue May 6 05:19:52 PDT 2025
================
@@ -731,6 +731,127 @@ struct ConcatSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
}
};
+struct PadSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
+ using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
+ PatternRewriter &rewriter) const override {
+ Value sliceInput = sliceOp.getInput1();
+
+ // Check if producer is a PadOp
+ auto padOp = sliceInput.getDefiningOp<tosa::PadOp>();
+ if (!padOp)
+ return rewriter.notifyMatchFailure(sliceOp,
+ "slice input must be a pad operation");
+
+ // Check PadOp has a single consumer
+ if (!padOp->hasOneUse())
+ return rewriter.notifyMatchFailure(sliceOp,
+ "pad shall have a single consumer");
+
+ // Check input is statically ranked
+ auto inputTy = dyn_cast<RankedTensorType>(padOp.getInput1().getType());
+ auto padTy = dyn_cast<RankedTensorType>(padOp.getType());
+ if (!inputTy || !padTy)
+ return rewriter.notifyMatchFailure(
+ sliceOp, "slice input must be a static ranked tensor");
+
+ // Validate and extract tosa::PadOp padding
+ DenseIntElementsAttr paddingElems;
+ if (!matchPattern(padOp.getPadding(), m_Constant(&paddingElems))) {
+ return rewriter.notifyMatchFailure(
+ sliceOp,
+ "The `padding` input specified on the tosa::PadOp must be constant.");
+ }
+ llvm::SmallVector<int64_t> padPaddings =
+ llvm::to_vector(paddingElems.getValues<int64_t>());
+
+ // Extract slice parameters
+ DenseElementsAttr startElems;
+ if (!matchPattern(sliceOp.getStart(), m_Constant(&startElems)))
+ return rewriter.notifyMatchFailure(
+ sliceOp, "start of slice must be a static ranked shape");
+ llvm::SmallVector<int64_t> sliceStarts =
+ llvm::to_vector(startElems.getValues<int64_t>());
+
+ DenseElementsAttr sizeElems;
+ if (!matchPattern(sliceOp.getSize(), m_Constant(&sizeElems)))
+ return rewriter.notifyMatchFailure(
+ sliceOp, "size of slice must be a static ranked shape");
+ llvm::SmallVector<int64_t> sliceSizes =
+ llvm::to_vector(sizeElems.getValues<int64_t>());
+
+ // Update the paddings
+ int64_t rank = inputTy.getRank();
+ llvm::SmallVector<int64_t> newSliceStarts(rank, 0);
+ llvm::SmallVector<int64_t> newPadPaddings(2 * rank, 0);
+ llvm::SmallVector<int64_t> newPadShape(rank, 0);
+ bool updated = false;
+ for (int64_t i = 0; i < rank; ++i) {
+ const int64_t padLo = padPaddings[i * 2];
+ const int64_t padHi = padPaddings[i * 2 + 1];
+ const int64_t sliceStart = sliceStarts[i];
+ const int64_t sliceSize = sliceSizes[i];
+ const int64_t sliceEnd = sliceStart + sliceSize;
+
+ const int64_t dimSize = inputTy.getShape()[i];
+ const int64_t dimStart = padLo;
+ const int64_t dimEnd = padLo + dimSize;
+ const int64_t dimTotal = padLo + dimSize + padHi;
+
+ // Check slice within bounds
+ if (sliceStart < 0 || sliceEnd > dimTotal)
+ return rewriter.notifyMatchFailure(sliceOp, "slice out-of-bounds");
+
+ const int64_t newPadLo = std::max<int64_t>(padLo - sliceStart, 0);
+ const int64_t newPadHi =
+ std::max<int64_t>(sliceEnd - (padLo + dimSize), 0);
----------------
lhutton1 wrote:
nit: I think these two computations (`newPadLo` and `newPadHi`) can be moved under `if (sliceStart < dimStart || sliceEnd > dimEnd)`
https://github.com/llvm/llvm-project/pull/138270
More information about the Mlir-commits mailing list