[Mlir-commits] [mlir] 2a88feb - [mlir][tosa] Canonicalize slice over overlapped or inside a pad. (#138900)
llvmlistbot at llvm.org
Thu May 8 02:50:36 PDT 2025
Author: Georgios Pinitas
Date: 2025-05-08T10:50:33+01:00
New Revision: 2a88feb3947606679453f886d79db611cdaef9fc
URL: https://github.com/llvm/llvm-project/commit/2a88feb3947606679453f886d79db611cdaef9fc
DIFF: https://github.com/llvm/llvm-project/commit/2a88feb3947606679453f886d79db611cdaef9fc.diff
LOG: [mlir][tosa] Canonicalize slice over overlapped or inside a pad. (#138900)
Update the padding and/or the slice parameters when a `tosa.slice`
that follows a `tosa.pad` accesses a region that only partially overlaps
the padding, or lies entirely inside the unpadded tensor.
Signed-off-by: Georgios Pinitas <georgios.pinitas at arm.com>
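As a quick illustration (mirroring the @canonicalize_pad_slice_overlap test
added below): the slice keeps one element of padding on each side of axis 2,
so the rewrite shrinks that axis's padding from [2, 2] to [1, 1] and rebases
the slice start onto the new padded tensor:

  // Before: pad axis 2 of tensor<?x16x16x3xf32> by [2, 2], then slice with
  // start = [0, 0, 1, 0] and size = [-1, 14, 18, 3].
  %padded = tosa.pad %arg0, %padding, %pad_const
  %sliced = tosa.slice %padded, %start, %size

  // After: the padding on axis 2 shrinks to [1, 1] and the slice starts at
  // zero (%new_padding and %zero_start stand for the updated const shapes).
  %padded = tosa.pad %arg0, %new_padding, %pad_const
  %sliced = tosa.slice %padded, %zero_start, %size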
Added:
Modified:
mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
mlir/test/Dialect/Tosa/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index e73e2c4e33522..92b620473d2a0 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -731,6 +731,141 @@ struct ConcatSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
}
};
+struct PadSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
+  using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
+                                PatternRewriter &rewriter) const override {
+    Value sliceInput = sliceOp.getInput1();
+
+    // Check if the producer is a PadOp
+    auto padOp = sliceInput.getDefiningOp<tosa::PadOp>();
+    if (!padOp)
+      return rewriter.notifyMatchFailure(sliceOp,
+                                         "slice input must be a pad operation");
+
+    // Check PadOp has a single consumer
+    if (!padOp->hasOneUse())
+      return rewriter.notifyMatchFailure(sliceOp,
+                                         "pad shall have a single consumer");
+
+    // Check the inputs are ranked tensors
+    auto inputTy = dyn_cast<RankedTensorType>(padOp.getInput1().getType());
+    auto padTy = dyn_cast<RankedTensorType>(padOp.getType());
+    if (!inputTy || !padTy || !inputTy.hasRank())
+      return rewriter.notifyMatchFailure(sliceOp,
+                                         "slice input must be a ranked tensor");
+
+    // Validate and extract tosa::PadOp padding
+    DenseIntElementsAttr paddingElems;
+    if (!matchPattern(padOp.getPadding(), m_Constant(&paddingElems))) {
+      return rewriter.notifyMatchFailure(
+          sliceOp,
+          "`padding` input specified on the tosa::PadOp must be constant.");
+    }
+    llvm::SmallVector<int64_t> padPaddings =
+        llvm::to_vector(paddingElems.getValues<int64_t>());
+
+    // Extract slice parameters
+    DenseElementsAttr startElems;
+    if (!matchPattern(sliceOp.getStart(), m_Constant(&startElems)))
+      return rewriter.notifyMatchFailure(
+          sliceOp, "start of slice must be a static ranked shape");
+    llvm::SmallVector<int64_t> sliceStarts =
+        llvm::to_vector(startElems.getValues<int64_t>());
+
+    DenseElementsAttr sizeElems;
+    if (!matchPattern(sliceOp.getSize(), m_Constant(&sizeElems)))
+      return rewriter.notifyMatchFailure(
+          sliceOp, "size of slice must be a static ranked shape");
+    llvm::SmallVector<int64_t> sliceSizes =
+        llvm::to_vector(sizeElems.getValues<int64_t>());
+
+    // Check if dynamic dimensions are sliced
+    const int64_t rank = inputTy.getRank();
+    if (llvm::any_of(llvm::seq<int64_t>(0, rank), [&](int64_t i) {
+          const bool isDimDynamic = inputTy.isDynamicDim(i);
+          const bool isDimSliced =
+              (sliceStarts[i] != 0) || (sliceSizes[i] != -1);
+
+          return isDimDynamic && isDimSliced;
+        })) {
+      return rewriter.notifyMatchFailure(
+          sliceOp, "axes that are sliced shall be statically known.");
+    }
+
+    // Update the parameters
+    llvm::SmallVector<int64_t> newSliceStarts(rank, 0);
+    llvm::SmallVector<int64_t> newPadPaddings(2 * rank, 0);
+    llvm::SmallVector<int64_t> newPadShape(rank, ShapedType::kDynamic);
+    bool updated = false;
+
+    for (int64_t i = 0; i < rank; ++i) {
+      const int64_t padLo = padPaddings[i * 2];
+      const int64_t padHi = padPaddings[i * 2 + 1];
+      const int64_t sliceStart = sliceStarts[i];
+      const int64_t sliceSize = sliceSizes[i];
+      const int64_t sliceEnd = sliceStart + sliceSize;
+
+      // If the dimension is dynamic, pass it through unchanged
+      if (inputTy.isDynamicDim(i)) {
+        newPadPaddings[i * 2] = padLo;
+        newPadPaddings[i * 2 + 1] = padHi;
+        newSliceStarts[i] = sliceStart;
+        continue;
+      }
+
+      // Handle static dimensions
+      const int64_t dimSize = inputTy.getShape()[i];
+      const int64_t dimTotal = padLo + dimSize + padHi;
+
+      // Check the slice is within bounds
+      if (sliceStart < 0 || sliceEnd > dimTotal)
+        return rewriter.notifyMatchFailure(sliceOp, "slice is out-of-bounds");
+
+      // Compute the updated slice start parameter
+      const int64_t newSliceStart = std::max<int64_t>(sliceStart - padLo, 0);
+      newSliceStarts[i] = newSliceStart;
+      updated |= newSliceStart != sliceStart;
+
+      // Compute the updated pad parameters
+      const int64_t newPadLo = std::max<int64_t>(padLo - sliceStart, 0);
+      const int64_t newPadHi =
+          std::max<int64_t>(sliceEnd - (padLo + dimSize), 0);
+      newPadPaddings[i * 2] = newPadLo;
+      newPadPaddings[i * 2 + 1] = newPadHi;
+      updated |= (newPadLo != padLo) || (newPadHi != padHi);
+
+      // Calculate the new pad output shape
+      newPadShape[i] =
+          newPadPaddings[i * 2] + dimSize + newPadPaddings[i * 2 + 1];
+    }
+
+    // Check that we actually need to proceed with the rewrite
+    if (!updated)
+      return rewriter.notifyMatchFailure(
+          sliceOp, "termination condition; nothing to rewrite");
+
+    // Create a PadOp with the updated padding
+    auto newPaddingsOp =
+        getTosaConstShape(rewriter, sliceOp.getLoc(), newPadPaddings);
+    auto newPadTy =
+        RankedTensorType::get(newPadShape, inputTy.getElementType());
+    auto newPadOp = rewriter.create<tosa::PadOp>(
+        padOp.getLoc(), newPadTy, padOp.getInput1(), newPaddingsOp,
+        padOp.getPadConst());
+
+    // Update the SliceOp to point at the new PadOp
+    auto newStartOp =
+        getTosaConstShape(rewriter, sliceOp.getLoc(), newSliceStarts);
+    rewriter.replaceOpWithNewOp<tosa::SliceOp>(sliceOp, sliceOp.getType(),
+                                               newPadOp.getResult(), newStartOp,
+                                               sliceOp.getSize());
+
+    return success();
+  }
+};
+
// Update size operand of tosa.slice if size has dynamic dims but corresponding
// output dim is static
struct SliceDynamicSizeCanonicalization
@@ -779,8 +914,8 @@ struct SliceDynamicSizeCanonicalization
 void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
-  results.add<ConcatSliceOptimization, SliceDynamicSizeCanonicalization>(
-      context);
+  results.add<ConcatSliceOptimization, PadSliceOptimization,
+              SliceDynamicSizeCanonicalization>(context);
 }
//===----------------------------------------------------------------------===//
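The per-dimension update in PadSliceOptimization above is plain interval
arithmetic: rebase the slice start past the low padding, and keep only as much
padding as the slice still reads. Below is a self-contained C++ sketch of just
that arithmetic for a statically sized axis, as a hypothetical mirror for
illustration (AxisUpdate and updateAxis are made-up names, not code from this
commit):

  #include <algorithm>
  #include <cassert>
  #include <cstdint>
  #include <cstdio>

  // One static axis of a pad-then-slice: compute the rebased slice start and
  // the clipped low/high paddings, exactly as the loop in the pattern does.
  struct AxisUpdate {
    int64_t sliceStart; // new slice start, relative to the unpadded input
    int64_t padLo;      // low padding the slice still covers
    int64_t padHi;      // high padding the slice still covers
  };

  AxisUpdate updateAxis(int64_t padLo, int64_t padHi, int64_t dimSize,
                        int64_t sliceStart, int64_t sliceSize) {
    const int64_t sliceEnd = sliceStart + sliceSize;
    assert(sliceStart >= 0 && sliceEnd <= padLo + dimSize + padHi &&
           "slice must be in bounds, as the pattern verifies");
    return {std::max<int64_t>(sliceStart - padLo, 0),
            std::max<int64_t>(padLo - sliceStart, 0),
            std::max<int64_t>(sliceEnd - (padLo + dimSize), 0)};
  }

  int main() {
    // Axis 2 of @canonicalize_pad_slice_overlap below: a [2, 2] pad around
    // 16 elements, sliced over padded indices [1, 19).
    const AxisUpdate u = updateAxis(2, 2, 16, 1, 18);
    std::printf("start=%lld padLo=%lld padHi=%lld\n", (long long)u.sliceStart,
                (long long)u.padLo, (long long)u.padHi);
    // Prints: start=0 padLo=1 padHi=1, matching the CHECK lines in that test.
  }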
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index c98335cdafe65..27280807b0282 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -985,6 +985,78 @@ func.func @canonicalize_concat_slice_on_non_concat_axis(%arg0 : tensor<1x12x12xf
// -----
+// CHECK-LABEL: @canonicalize_pad_slice_overlap
+// CHECK-DAG: %[[PAD_CONST:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK-DAG: %[[ZERO:.*]] = tosa.const_shape {values = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>
+// CHECK-DAG: %[[PADDING:.*]] = tosa.const_shape {values = dense<[0, 0, 0, 0, 1, 1, 0, 0]> : tensor<8xindex>}
+// CHECK-DAG: %[[SLICE_SIZE:.*]] = tosa.const_shape {values = dense<[-1, 14, 18, 3]> : tensor<4xindex>}
+// CHECK: %[[PADDED:.*]] = tosa.pad %arg0, %[[PADDING]], %[[PAD_CONST]]
+// CHECK: %[[SLICED:.*]] = tosa.slice %[[PADDED]], %[[ZERO]], %[[SLICE_SIZE]]
+func.func @canonicalize_pad_slice_overlap(%arg0: tensor<?x16x16x3xf32>) -> tensor<?x14x18x3xf32> {
+  %pad_const = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %padding = tosa.const_shape {values = dense<[0, 0, 0, 0, 2, 2, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
+  %padded = tosa.pad %arg0, %padding, %pad_const : (tensor<?x16x16x3xf32>, !tosa.shape<8>, tensor<1xf32>) -> tensor<?x16x20x3xf32>
+  %start = tosa.const_shape {values = dense<[0, 0, 1, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %size = tosa.const_shape {values = dense<[-1, 14, 18, 3]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %sliced = tosa.slice %padded, %start, %size : (tensor<?x16x20x3xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<?x14x18x3xf32>
+  return %sliced : tensor<?x14x18x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @canonicalize_pad_slice_inside
+// CHECK-DAG: %[[SLICE_START:.*]] = tosa.const_shape {values = dense<[0, 1, 2, 0]> : tensor<4xindex>}
+// CHECK-DAG: %[[SLICE_SIZE:.*]] = tosa.const_shape {values = dense<[1, 14, 10, 3]> : tensor<4xindex>}
+// CHECK-NOT: tosa.pad
+// CHECK: %[[SLICED:.*]] = tosa.slice %arg0, %[[SLICE_START]], %[[SLICE_SIZE]]
+func.func @canonicalize_pad_slice_inside(%arg0: tensor<1x16x16x3xf32>) -> tensor<1x14x14x3xf32> {
+  %pad_const = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %padding = tosa.const_shape {values = dense<[0, 0, 0, 0, 2, 2, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
+  %padded = tosa.pad %arg0, %padding, %pad_const : (tensor<1x16x16x3xf32>, !tosa.shape<8>, tensor<1xf32>) -> tensor<1x16x20x3xf32>
+  %start = tosa.const_shape {values = dense<[0, 1, 4, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %size = tosa.const_shape {values = dense<[1, 14, 10, 3]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %sliced = tosa.slice %padded, %start, %size : (tensor<1x16x20x3xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<1x14x14x3xf32>
+  return %sliced : tensor<1x14x14x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @canonicalize_pad_slice_exact
+// CHECK-DAG: %[[PAD_CONST:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK-DAG: %[[ZERO:.*]] = tosa.const_shape {values = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>
+// CHECK-DAG: %[[PADDING:.*]] = tosa.const_shape {values = dense<[0, 0, 0, 0, 2, 2, 0, 0]> : tensor<8xindex>}
+// CHECK-DAG: %[[SLICE_SIZE:.*]] = tosa.const_shape {values = dense<[1, 16, 20, 2]> : tensor<4xindex>}
+// CHECK: %[[PADDED:.*]] = tosa.pad %arg0, %[[PADDING]], %[[PAD_CONST]]
+// CHECK: %[[SLICED:.*]] = tosa.slice %[[PADDED]], %[[ZERO]], %[[SLICE_SIZE]]
+func.func @canonicalize_pad_slice_exact(%arg0: tensor<1x16x16x3xf32>) -> tensor<1x16x20x2xf32> {
+  %pad_const = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %padding = tosa.const_shape {values = dense<[0, 0, 0, 0, 2, 2, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
+  %padded = tosa.pad %arg0, %padding, %pad_const : (tensor<1x16x16x3xf32>, !tosa.shape<8>, tensor<1xf32>) -> tensor<1x16x20x3xf32>
+  %start = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %size = tosa.const_shape {values = dense<[1, 16, 20, 2]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %sliced = tosa.slice %padded, %start, %size : (tensor<1x16x20x3xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<1x16x20x2xf32>
+  return %sliced : tensor<1x16x20x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @canonicalize_pad_slice_dynamic_noupdate
+// CHECK-DAG: tosa.const_shape {values = dense<[0, 0, 0, 0, 2, 2, 0, 0]> : tensor<8xindex>}
+// CHECK-DAG: tosa.const_shape {values = dense<[1, 16, 15, 2]> : tensor<4xindex>}
+// CHECK: tosa.pad
+// CHECK: tosa.slice
+func.func @canonicalize_pad_slice_dynamic_noupdate(%arg0: tensor<1x16x?x3xf32>) -> tensor<1x16x?x2xf32> {
+  %pad_const = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %padding = tosa.const_shape {values = dense<[0, 0, 0, 0, 2, 2, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
+  %padded = tosa.pad %arg0, %padding, %pad_const : (tensor<1x16x?x3xf32>, !tosa.shape<8>, tensor<1xf32>) -> tensor<1x16x?x3xf32>
+  %start = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %size = tosa.const_shape {values = dense<[1, 16, 15, 2]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %sliced = tosa.slice %padded, %start, %size : (tensor<1x16x?x3xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<1x16x?x2xf32>
+  return %sliced : tensor<1x16x?x2xf32>
+}
+
+// -----
+
// CHECK-LABEL: @fold_log_exp
func.func @fold_log_exp(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
// CHECK: return %arg{{.*}} : tensor<?x1xf32>
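For the "inside" case named in the subject, every element the slice reads
comes from the original tensor, so the tosa.pad folds away entirely and only a
rebased tosa.slice on the unpadded input remains, as
@canonicalize_pad_slice_inside checks. A minimal sketch with illustrative
shapes (chosen for this example, not copied from the tests):

  // Pad axis 2 of tensor<1x16x16x3xf32> by [2, 2], then slice with
  // start = [0, 1, 4, 0] and size = [1, 14, 14, 3]. The slice covers padded
  // indices [4, 18) on axis 2, i.e. data indices [2, 16), so the pad is
  // dropped and the start rebases to [0, 1, 2, 0]:
  %sliced = tosa.slice %arg0, %new_start, %size
      : (tensor<1x16x16x3xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<1x14x14x3xf32>

These cases can be reproduced by running the canonicalize pass over the test
file, e.g. with `mlir-opt --split-input-file --canonicalize` (the file's
actual RUN line is not shown in this diff).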