[Mlir-commits] [mlir] [mlir][tosa] Canonicalise slice over overlapped or inside a pad. (PR #138270)
Luke Hutton
llvmlistbot at llvm.org
Wed May 7 04:58:53 PDT 2025
================
@@ -731,6 +731,127 @@ struct ConcatSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
}
};
+struct PadSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
+ using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
+ PatternRewriter &rewriter) const override {
+ Value sliceInput = sliceOp.getInput1();
+
+ // Check if producer is a PadOp
+ auto padOp = sliceInput.getDefiningOp<tosa::PadOp>();
+ if (!padOp)
+ return rewriter.notifyMatchFailure(sliceOp,
+ "slice input must be a pad operation");
+
+ // Check PadOp has a single consumer
+ if (!padOp->hasOneUse())
+ return rewriter.notifyMatchFailure(sliceOp,
+ "pad shall have a single consumer");
+
+ // Check input is statically ranked
+ auto inputTy = dyn_cast<RankedTensorType>(padOp.getInput1().getType());
+ auto padTy = dyn_cast<RankedTensorType>(padOp.getType());
+ if (!inputTy || !padTy)
+ return rewriter.notifyMatchFailure(
+ sliceOp, "slice input must be a static ranked tensor");
+
+ // Validate and extract tosa::PadOp padding
+ DenseIntElementsAttr paddingElems;
+ if (!matchPattern(padOp.getPadding(), m_Constant(&paddingElems))) {
+ return rewriter.notifyMatchFailure(
+ sliceOp,
+ "The `padding` input specified on the tosa::PadOp must be constant.");
+ }
+ llvm::SmallVector<int64_t> padPaddings =
+ llvm::to_vector(paddingElems.getValues<int64_t>());
+
+ // Extract slice parameters
+ DenseElementsAttr startElems;
+ if (!matchPattern(sliceOp.getStart(), m_Constant(&startElems)))
+ return rewriter.notifyMatchFailure(
+ sliceOp, "start of slice must be a static ranked shape");
+ llvm::SmallVector<int64_t> sliceStarts =
+ llvm::to_vector(startElems.getValues<int64_t>());
+
+ DenseElementsAttr sizeElems;
+ if (!matchPattern(sliceOp.getSize(), m_Constant(&sizeElems)))
+ return rewriter.notifyMatchFailure(
+ sliceOp, "size of slice must be a static ranked shape");
+ llvm::SmallVector<int64_t> sliceSizes =
+ llvm::to_vector(sizeElems.getValues<int64_t>());
+
+ // Update the paddings
+ int64_t rank = inputTy.getRank();
+ llvm::SmallVector<int64_t> newSliceStarts(rank, 0);
+ llvm::SmallVector<int64_t> newPadPaddings(2 * rank, 0);
+ llvm::SmallVector<int64_t> newPadShape(rank, 0);
+ bool updated = false;
+ for (int64_t i = 0; i < rank; ++i) {
+ const int64_t padLo = padPaddings[i * 2];
+ const int64_t padHi = padPaddings[i * 2 + 1];
----------------
lhutton1 wrote:
I missed the following use further down: https://github.com/llvm/llvm-project/pull/138270/files#diff-a09923d3e3c3e2c9cfa17a22fd05346bd8e3caa2c7332e64cba7f51f2748d916R798. Please ignore my comment.
https://github.com/llvm/llvm-project/pull/138270
More information about the Mlir-commits
mailing list