[Mlir-commits] [mlir] [mlir][tosa] Canonicalise slice over overlapped or inside a pad. (PR #138270)

Georgios Pinitas llvmlistbot at llvm.org
Fri May 2 06:15:33 PDT 2025


https://github.com/GeorgeARM created https://github.com/llvm/llvm-project/pull/138270

Update the paddings and/or the slice parameters when a `tosa.slice` that follows a `tosa.pad` reads only part of the padded tensor: either a window that overlaps the padding, or one that lies entirely inside the original (unpadded) input.

>From 64d81c8386d7463f38fdcc42e70f6a62bfbf1567 Mon Sep 17 00:00:00 2001
From: Georgios Pinitas <georgios.pinitas at arm.com>
Date: Thu, 1 May 2025 16:02:58 +0100
Subject: [PATCH] [mlir][tosa] Canonicalize slice over overlapped or inside a
 pad.

Update the paddings and/or the slice parameters when a `tosa.slice` that
follows a `tosa.pad` reads only part of the padded tensor: either a window
that overlaps the padding, or one that lies entirely inside the original
(unpadded) input.

Signed-off-by: Georgios Pinitas <georgios.pinitas at arm.com>
---
 .../Dialect/Tosa/IR/TosaCanonicalizations.cpp | 125 +++++++++++++++++-
 mlir/test/Dialect/Tosa/canonicalize.mlir      |  36 +++++
 2 files changed, 159 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 47368532df169..5347fb1c16698 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -731,6 +731,127 @@ struct ConcatSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
   }
 };
 
+// Folds a `tosa.slice` into its producing `tosa.pad`: padding outside the
+// sliced window is dropped and the slice start is rebased onto the smaller
+// pad. A slice lying entirely inside the pad input keeps no padding at all.
+// Requires a single-use pad and fully static shapes/parameters.
+struct PadSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
+  using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
+                                PatternRewriter &rewriter) const override {
+    Value sliceInput = sliceOp.getInput1();
+
+    // Check if producer is a PadOp
+    auto padOp = sliceInput.getDefiningOp<tosa::PadOp>();
+    if (!padOp)
+      return rewriter.notifyMatchFailure(sliceOp,
+                                         "slice input must be a pad operation");
+
+    // Check PadOp has a single consumer
+    if (!padOp->hasOneUse())
+      return rewriter.notifyMatchFailure(sliceOp,
+                                         "pad shall have a single consumer");
+
+    // Check input is ranked and fully static (dim sizes are used directly)
+    auto inputTy = dyn_cast<RankedTensorType>(padOp.getInput1().getType());
+    auto padTy = dyn_cast<RankedTensorType>(padOp.getType());
+    if (!inputTy || !inputTy.hasStaticShape() || !padTy)
+      return rewriter.notifyMatchFailure(
+          sliceOp, "slice input must be a static ranked tensor");
+
+    // Validate and extract tosa::PadOp padding
+    DenseIntElementsAttr paddingElems;
+    if (!matchPattern(padOp.getPadding(), m_Constant(&paddingElems))) {
+      return rewriter.notifyMatchFailure(
+          sliceOp,
+          "The `padding` input specified on the tosa::PadOp must be constant.");
+    }
+    llvm::SmallVector<int64_t> padPaddings =
+        llvm::to_vector(paddingElems.getValues<int64_t>());
+
+    // Extract slice parameters
+    DenseElementsAttr startElems;
+    if (!matchPattern(sliceOp.getStart(), m_Constant(&startElems)))
+      return rewriter.notifyMatchFailure(
+          sliceOp, "start of slice must be a static ranked shape");
+    llvm::SmallVector<int64_t> sliceStarts =
+        llvm::to_vector(startElems.getValues<int64_t>());
+
+    DenseElementsAttr sizeElems;
+    if (!matchPattern(sliceOp.getSize(), m_Constant(&sizeElems)))
+      return rewriter.notifyMatchFailure(
+          sliceOp, "size of slice must be a static ranked shape");
+    llvm::SmallVector<int64_t> sliceSizes =
+        llvm::to_vector(sizeElems.getValues<int64_t>());
+
+    // Compute the new pad/slice parameters one dimension at a time
+    const int64_t rank = inputTy.getRank();
+    llvm::SmallVector<int64_t> newSliceStarts(rank, 0);
+    llvm::SmallVector<int64_t> newPadPaddings(2 * rank, 0);
+    llvm::SmallVector<int64_t> newPadShape(rank, 0);
+    bool updated = false;
+    for (int64_t i = 0; i < rank; ++i) {
+      const int64_t padLo = padPaddings[i * 2];
+      const int64_t padHi = padPaddings[i * 2 + 1];
+      const int64_t sliceStart = sliceStarts[i];
+      const int64_t sliceSize = sliceSizes[i];
+      const int64_t sliceEnd = sliceStart + sliceSize;
+
+      const int64_t dimSize = inputTy.getShape()[i];
+      const int64_t dimEnd = padLo + dimSize;
+      const int64_t dimTotal = dimEnd + padHi;
+
+      // Reject negative paddings/sizes (e.g. a -1 dynamic slice size):
+      // such extents cannot be rebased onto a smaller pad.
+      if (padLo < 0 || padHi < 0 || sliceSize < 0)
+        return rewriter.notifyMatchFailure(sliceOp, "negative pad or size");
+
+      // Check slice within bounds
+      if (sliceStart < 0 || sliceEnd > dimTotal)
+        return rewriter.notifyMatchFailure(sliceOp, "slice out-of-bounds");
+
+      // Keep only the padding overlapped by [sliceStart, sliceEnd) and rebase
+      // the start. A slice inside the original input gets zero padding.
+      const int64_t newPadLo = std::max<int64_t>(padLo - sliceStart, 0);
+      const int64_t newPadHi = std::max<int64_t>(sliceEnd - dimEnd, 0);
+      const int64_t newSliceStart = std::max<int64_t>(sliceStart - padLo, 0);
+
+      // Padding dropped from any dimension also counts as an update.
+      updated |= (newPadLo != padLo) || (newPadHi != padHi) ||
+                 (newSliceStart != sliceStart);
+      newPadPaddings[i * 2] = newPadLo;
+      newPadPaddings[i * 2 + 1] = newPadHi;
+      newSliceStarts[i] = newSliceStart;
+      newPadShape[i] = newPadLo + dimSize + newPadHi;
+    }
+
+    // Check that we actually need to proceed with the rewrite; the new
+    // parameters are a fixed point, so the pattern won't re-fire on them.
+    if (!updated)
+      return rewriter.notifyMatchFailure(
+          sliceOp, "terminate condition; nothing to rewrite");
+
+    // Create a PadOp with updated padding
+    auto newPaddingsOp =
+        getTosaConstShape(rewriter, sliceOp.getLoc(), newPadPaddings);
+    auto newPadTy =
+        RankedTensorType::get(newPadShape, inputTy.getElementType());
+    auto newPadOp = rewriter.create<tosa::PadOp>(
+        padOp.getLoc(), newPadTy, padOp.getInput1(), newPaddingsOp,
+        padOp.getPadConst());
+
+    // Update SliceOp and point to new PadOp
+    auto newStartOp =
+        getTosaConstShape(rewriter, sliceOp.getLoc(), newSliceStarts);
+    rewriter.replaceOpWithNewOp<tosa::SliceOp>(sliceOp, sliceOp.getType(),
+                                               newPadOp.getResult(), newStartOp,
+                                               sliceOp.getSize());
+
+    return success();
+  }
+};
+
 // Update size operand of tosa.slice if size has dynamic dims but corresponding
 // output dim is static
 struct SliceDynamicSizeCanonicalization
@@ -779,8 +900,8 @@ struct SliceDynamicSizeCanonicalization
 
 void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
-  results.add<ConcatSliceOptimization, SliceDynamicSizeCanonicalization>(
-      context);
+  results.add<ConcatSliceOptimization, PadSliceOptimization>(context);
+  results.add<SliceDynamicSizeCanonicalization>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 59fd490330691..6e99f57341982 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -985,6 +985,42 @@ func.func @canonicalize_concat_slice_on_non_concat_axis(%arg0 : tensor<1x12x12xf
 
 // -----
 
+// CHECK-LABEL: @canonicalize_pad_slice_overlap
+// CHECK-DAG: %[[PAD_CONST:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK-DAG: %[[ZERO:.*]] = tosa.const_shape {values = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>
+// CHECK-DAG: %[[PADDING:.*]] = tosa.const_shape {values = dense<[0, 0, 0, 0, 1, 1, 0, 0]> : tensor<8xindex>}
+// CHECK-DAG: %[[SLICE_SIZE:.*]] = tosa.const_shape  {values = dense<[1, 14, 18, 3]> : tensor<4xindex>}
+// CHECK: %[[PADDED:.*]] = tosa.pad %arg0, %[[PADDING]], %[[PAD_CONST]]
+// CHECK: %[[SLICED:.*]] = tosa.slice %[[PADDED]], %[[ZERO]], %[[SLICE_SIZE]]
+func.func @canonicalize_pad_slice_overlap(%arg0: tensor<1x16x16x3xf32>) -> tensor<1x14x18x3xf32> {
+  %fill = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %pads = tosa.const_shape  {values = dense<[0, 0, 0, 0, 2, 2, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
+  %pad_out = tosa.pad %arg0, %pads, %fill : (tensor<1x16x16x3xf32>, !tosa.shape<8>, tensor<1xf32>) -> tensor<1x16x20x3xf32>
+  %begin = tosa.const_shape  {values = dense<[0, 0, 1, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %extent = tosa.const_shape  {values = dense<[1, 14, 18, 3]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %window = tosa.slice %pad_out, %begin, %extent : (tensor<1x16x20x3xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<1x14x18x3xf32>
+  return %window : tensor<1x14x18x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @canonicalize_pad_slice_inside
+// CHECK-DAG: %[[SLICE_START:.*]] = tosa.const_shape  {values = dense<[0, 1, 2, 0]> : tensor<4xindex>}
+// CHECK-DAG: %[[SLICE_SIZE:.*]] = tosa.const_shape  {values = dense<[1, 14, 10, 3]> : tensor<4xindex>}
+// CHECK-NOT: tosa.pad
+// CHECK: %[[SLICED:.*]] = tosa.slice %arg0, %[[SLICE_START]], %[[SLICE_SIZE]]
+func.func @canonicalize_pad_slice_inside(%arg0: tensor<1x16x16x3xf32>) -> tensor<1x14x14x3xf32> {
+  %fill = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %pads = tosa.const_shape  {values = dense<[0, 0, 0, 0, 2, 2, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
+  %pad_out = tosa.pad %arg0, %pads, %fill : (tensor<1x16x16x3xf32>, !tosa.shape<8>, tensor<1xf32>) -> tensor<1x16x20x3xf32>
+  %begin = tosa.const_shape  {values = dense<[0, 1, 4, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %extent = tosa.const_shape  {values = dense<[1, 14, 10, 3]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %window = tosa.slice %pad_out, %begin, %extent : (tensor<1x16x20x3xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<1x14x14x3xf32>
+  return %window : tensor<1x14x14x3xf32>
+}
+
+// -----
+
 // CHECK-LABEL: @fold_log_exp
 func.func @fold_log_exp(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   // CHECK: return %arg{{.*}} : tensor<?x1xf32>



More information about the Mlir-commits mailing list