[Mlir-commits] [mlir] [mlir][tosa] Canonicalise slice over overlapped or inside a pad. (PR #138270)
Georgios Pinitas
llvmlistbot at llvm.org
Fri May 2 06:15:33 PDT 2025
https://github.com/GeorgeARM created https://github.com/llvm/llvm-project/pull/138270
Update the paddings and/or the slice parameters when a `tosa.slice` consuming a `tosa.pad` reads only part of the padded tensor: a region that overlaps the padding, or one that lies entirely inside the input.
>From 64d81c8386d7463f38fdcc42e70f6a62bfbf1567 Mon Sep 17 00:00:00 2001
From: Georgios Pinitas <georgios.pinitas at arm.com>
Date: Thu, 1 May 2025 16:02:58 +0100
Subject: [PATCH] [mlir][tosa] Canonicalize slice over overlapped or inside a
pad.
Update the paddings and/or the slice parameters when a `tosa.slice`
consuming a `tosa.pad` reads only part of the padded tensor: a region
that overlaps the padding, or one that lies entirely inside the input.
Signed-off-by: Georgios Pinitas <georgios.pinitas at arm.com>
---
.../Dialect/Tosa/IR/TosaCanonicalizations.cpp | 125 +++++++++++++++++-
mlir/test/Dialect/Tosa/canonicalize.mlir | 36 +++++
2 files changed, 159 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 47368532df169..5347fb1c16698 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -731,6 +731,127 @@ struct ConcatSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
}
};
+struct PadSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
+  using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;
+
+  /// Shrink the padding of a `tosa.pad` feeding a `tosa.slice` to the region
+  /// the slice actually reads and rebase the slice start onto the new pad.
+  /// Requires a single-use pad, a static pad input and constant padding,
+  /// start and size; fails to match when nothing would change.
+  LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
+                                PatternRewriter &rewriter) const override {
+    // Check if producer is a PadOp
+    auto padOp = sliceOp.getInput1().getDefiningOp<tosa::PadOp>();
+    if (!padOp)
+      return rewriter.notifyMatchFailure(sliceOp,
+                                         "slice input must be a pad operation");
+
+    // Check PadOp has a single consumer; otherwise the old pad would survive
+    // next to the new one.
+    if (!padOp->hasOneUse())
+      return rewriter.notifyMatchFailure(sliceOp,
+                                         "pad shall have a single consumer");
+
+    // Pad input must be statically shaped: dynamic dims break the math below.
+    auto inputTy = dyn_cast<RankedTensorType>(padOp.getInput1().getType());
+    if (!inputTy || !inputTy.hasStaticShape())
+      return rewriter.notifyMatchFailure(
+          sliceOp, "slice input must be a static ranked tensor");
+
+    // Validate and extract tosa::PadOp padding ([lo0, hi0, lo1, hi1, ...])
+    DenseIntElementsAttr paddingElems;
+    if (!matchPattern(padOp.getPadding(), m_Constant(&paddingElems)))
+      return rewriter.notifyMatchFailure(
+          sliceOp,
+          "The `padding` input specified on the tosa::PadOp must be constant.");
+    llvm::SmallVector<int64_t> padPaddings =
+        llvm::to_vector(paddingElems.getValues<int64_t>());
+
+    // Extract slice parameters
+    DenseElementsAttr startElems;
+    if (!matchPattern(sliceOp.getStart(), m_Constant(&startElems)))
+      return rewriter.notifyMatchFailure(
+          sliceOp, "start of slice must be a static ranked shape");
+    llvm::SmallVector<int64_t> sliceStarts =
+        llvm::to_vector(startElems.getValues<int64_t>());
+
+    DenseElementsAttr sizeElems;
+    if (!matchPattern(sliceOp.getSize(), m_Constant(&sizeElems)))
+      return rewriter.notifyMatchFailure(
+          sliceOp, "size of slice must be a static ranked shape");
+    llvm::SmallVector<int64_t> sliceSizes =
+        llvm::to_vector(sizeElems.getValues<int64_t>());
+
+    // Guard every indexed access below against malformed operand lengths.
+    const int64_t rank = inputTy.getRank();
+    if (static_cast<int64_t>(padPaddings.size()) != 2 * rank ||
+        static_cast<int64_t>(sliceStarts.size()) != rank ||
+        static_cast<int64_t>(sliceSizes.size()) != rank)
+      return rewriter.notifyMatchFailure(sliceOp, "rank mismatch");
+
+    // Update the paddings
+    llvm::SmallVector<int64_t> newSliceStarts(rank, 0);
+    llvm::SmallVector<int64_t> newPadPaddings(2 * rank, 0);
+    llvm::SmallVector<int64_t> newPadShape(rank, 0);
+    bool updated = false;
+    for (int64_t i = 0; i < rank; ++i) {
+      const int64_t padLo = padPaddings[i * 2];
+      const int64_t padHi = padPaddings[i * 2 + 1];
+      const int64_t sliceStart = sliceStarts[i];
+      const int64_t sliceSize = sliceSizes[i];
+      const int64_t sliceEnd = sliceStart + sliceSize;
+
+      const int64_t dimSize = inputTy.getShape()[i];
+      const int64_t dimEnd = padLo + dimSize;
+      const int64_t dimTotal = dimEnd + padHi;
+
+      // Reject negative (e.g. dynamic -1) sizes and out-of-range windows.
+      if (sliceSize < 0)
+        return rewriter.notifyMatchFailure(sliceOp, "dynamic slice size");
+      if (sliceStart < 0 || sliceEnd > dimTotal)
+        return rewriter.notifyMatchFailure(sliceOp, "slice out-of-bounds");
+
+      // Keep only the padding the window reads. Input element j moves from
+      // padLo + j to newPadLo + j, so the start shifts by newPadLo - padLo; a
+      // window inside the input gets zero padding and a rebased start.
+      const int64_t newPadLo = std::max<int64_t>(padLo - sliceStart, 0);
+      const int64_t newPadHi = std::max<int64_t>(sliceEnd - dimEnd, 0);
+      const int64_t newSliceStart = std::max<int64_t>(sliceStart - padLo, 0);
+
+      // Any change counts, including dropping padding the window never reads.
+      updated |= (newPadLo != padLo) || (newPadHi != padHi) ||
+                 (newSliceStart != sliceStart);
+      newPadPaddings[i * 2] = newPadLo;
+      newPadPaddings[i * 2 + 1] = newPadHi;
+      newSliceStarts[i] = newSliceStart;
+      newPadShape[i] = newPadLo + dimSize + newPadHi;
+    }
+
+    // Nothing changed: fail so the pattern reaches a fixed point.
+    if (!updated)
+      return rewriter.notifyMatchFailure(
+          sliceOp, "terminate condition; nothing to rewrite");
+
+    // Create a PadOp with updated padding
+    auto newPaddingsOp =
+        getTosaConstShape(rewriter, sliceOp.getLoc(), newPadPaddings);
+    auto newPadTy =
+        RankedTensorType::get(newPadShape, inputTy.getElementType());
+    auto newPadOp = rewriter.create<tosa::PadOp>(
+        padOp.getLoc(), newPadTy, padOp.getInput1(), newPaddingsOp,
+        padOp.getPadConst());
+
+    // Update SliceOp and point to new PadOp (the size operand is unchanged)
+    auto newStartOp =
+        getTosaConstShape(rewriter, sliceOp.getLoc(), newSliceStarts);
+    rewriter.replaceOpWithNewOp<tosa::SliceOp>(sliceOp, sliceOp.getType(),
+                                               newPadOp.getResult(), newStartOp,
+                                               sliceOp.getSize());
+
+    return success();
+  }
+};
+
// Update size operand of tosa.slice if size has dynamic dims but corresponding
// output dim is static
struct SliceDynamicSizeCanonicalization
@@ -779,8 +900,8 @@ struct SliceDynamicSizeCanonicalization
void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<ConcatSliceOptimization, SliceDynamicSizeCanonicalization>(
- context);
+  results.add<ConcatSliceOptimization, PadSliceOptimization>(context);
+  results.add<SliceDynamicSizeCanonicalization>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 59fd490330691..6e99f57341982 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -985,6 +985,42 @@ func.func @canonicalize_concat_slice_on_non_concat_axis(%arg0 : tensor<1x12x12xf
// -----
+// CHECK-LABEL: @canonicalize_pad_slice_overlap
+// CHECK-DAG: %[[PAD_CONST:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK-DAG: %[[ZERO:.*]] = tosa.const_shape {values = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>
+// CHECK-DAG: %[[PADDING:.*]] = tosa.const_shape {values = dense<[0, 0, 0, 0, 1, 1, 0, 0]> : tensor<8xindex>}
+// CHECK-DAG: %[[SLICE_SIZE:.*]] = tosa.const_shape {values = dense<[1, 14, 18, 3]> : tensor<4xindex>}
+// CHECK: %[[PADDED:.*]] = tosa.pad %arg0, %[[PADDING]], %[[PAD_CONST]] : {{.*}} -> tensor<1x16x18x3xf32>
+// CHECK: %[[SLICED:.*]] = tosa.slice %[[PADDED]], %[[ZERO]], %[[SLICE_SIZE]] : {{.*}} -> tensor<1x14x18x3xf32>
+func.func @canonicalize_pad_slice_overlap(%arg0: tensor<1x16x16x3xf32>) -> tensor<1x14x18x3xf32> {
+ %pad_const = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+ %padding = tosa.const_shape {values = dense<[0, 0, 0, 0, 2, 2, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
+ %padded = tosa.pad %arg0, %padding, %pad_const : (tensor<1x16x16x3xf32>, !tosa.shape<8>, tensor<1xf32>) -> tensor<1x16x20x3xf32>
+ %start = tosa.const_shape {values = dense<[0, 0, 1, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+ %size = tosa.const_shape {values = dense<[1, 14, 18, 3]> : tensor<4xindex>} : () -> !tosa.shape<4>
+ %sliced = tosa.slice %padded, %start, %size : (tensor<1x16x20x3xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<1x14x18x3xf32>
+ return %sliced : tensor<1x14x18x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @canonicalize_pad_slice_inside
+// CHECK-DAG: %[[SLICE_START:.*]] = tosa.const_shape {values = dense<[0, 1, 2, 0]> : tensor<4xindex>}
+// CHECK-DAG: %[[SLICE_SIZE:.*]] = tosa.const_shape {values = dense<[1, 14, 10, 3]> : tensor<4xindex>}
+// CHECK-NOT: tosa.pad
+// CHECK: %[[SLICED:.*]] = tosa.slice %arg0, %[[SLICE_START]], %[[SLICE_SIZE]] : {{.*}} -> tensor<1x14x10x3xf32>
+func.func @canonicalize_pad_slice_inside(%arg0: tensor<1x16x16x3xf32>) -> tensor<1x14x10x3xf32> {
+ %pad_const = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+ %padding = tosa.const_shape {values = dense<[0, 0, 0, 0, 2, 2, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
+ %padded = tosa.pad %arg0, %padding, %pad_const : (tensor<1x16x16x3xf32>, !tosa.shape<8>, tensor<1xf32>) -> tensor<1x16x20x3xf32>
+ %start = tosa.const_shape {values = dense<[0, 1, 4, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+ %size = tosa.const_shape {values = dense<[1, 14, 10, 3]> : tensor<4xindex>} : () -> !tosa.shape<4>
+ %sliced = tosa.slice %padded, %start, %size : (tensor<1x16x20x3xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<1x14x10x3xf32>
+ return %sliced : tensor<1x14x10x3xf32>
+}
+
+// -----
+
// CHECK-LABEL: @fold_log_exp
func.func @fold_log_exp(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
// CHECK: return %arg{{.*}} : tensor<?x1xf32>
More information about the Mlir-commits
mailing list