[Mlir-commits] [mlir] [tosa]: Enhance tosa.slice folding for dynamic dims. (PR #184615)
Sayan Saha
llvmlistbot at llvm.org
Wed Mar 4 06:18:04 PST 2026
https://github.com/sahas3 updated https://github.com/llvm/llvm-project/pull/184615
>From bee6456b7f49a08d24ac526682c33755314b3bc1 Mon Sep 17 00:00:00 2001
From: Sayan Saha <sayans at mathworks.com>
Date: Wed, 4 Mar 2026 09:13:53 -0500
Subject: [PATCH 1/2] [tosa]: Enhance tosa.slice folding for dynamic dims.
---
.../Dialect/Tosa/IR/TosaCanonicalizations.cpp | 57 ++++++++++++++++---
mlir/test/Dialect/Tosa/canonicalize.mlir | 33 +++++++++++
2 files changed, 83 insertions(+), 7 deletions(-)
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 7a1dbcd3e84c7..571bd684af4c2 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -754,7 +754,7 @@ struct PadSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
if (llvm::any_of(llvm::seq<int64_t>(0, rank), [&](int64_t i) {
const bool isDimDynamic = inputTy.isDynamicDim(i);
const bool isDimSliced =
- (sliceStarts[i] != 0) || (sliceSizes[i] != -1);
+ (sliceStarts[i] != 0) || (sliceSizes[i] != kInferableDimSize);
return isDimDynamic && isDimSliced;
})) {
@@ -854,11 +854,11 @@ struct SliceDynamicSizeCanonicalization
llvm::to_vector(sizeElems.getValues<int64_t>());
bool replaceSliceSize{false};
- // if size op has -1 indicating dynamic shape but corresponding dim on the
+ // if size op has kInferableDimSize indicating dynamic shape but corresponding dim on the
// output is statically known, update size to match with known output dim
// shape
for (const auto &[index, size] : llvm::enumerate(sliceSizes)) {
- if (size == -1 && !resultType.isDynamicDim(index)) {
+ if (size == kInferableDimSize && !resultType.isDynamicDim(index)) {
sliceSizes[index] = resultType.getDimSize(index);
replaceSliceSize = true;
}
@@ -1771,6 +1771,53 @@ OpFoldResult SliceOp::fold(FoldAdaptor adaptor) {
if (inputTy == outputTy && inputTy.hasStaticShape())
return getInput1();
+ // Check if this is a no-op slice (starts at 0 and size matches input)
+
+ DenseElementsAttr startElems;
+ if (!matchPattern(getStart(), m_Constant(&startElems)))
+ return {};
+
+ // Check if all start values are zero
+ bool startIsZeros =
+ llvm::all_of(startElems.getValues<APInt>(),
+ [](const APInt &val) { return val.isZero(); });
+
+ if (startIsZeros) {
+
+ // Check if size matches input shape
+ DenseElementsAttr sizeElems;
+ if (!matchPattern(getSize(), m_Constant(&sizeElems)))
+ return {};
+
+ auto inputShape = inputTy.getShape();
+ auto sizeValues = sizeElems.getValues<APInt>();
+
+ bool sizeMatchesInput = true;
+ for (const auto &[i, sizeVal] : llvm::enumerate(sizeValues)) {
+ int64_t size = sizeVal.getSExtValue();
+
+ if (inputTy.isDynamicDim(i)) {
+ // For dynamic dimensions, check for kInferableDimSize indicating full dimension is
+ // sliced
+ if (size != kInferableDimSize) {
+ sizeMatchesInput = false;
+ break;
+ }
+ } else {
+ // For static dimensions, check that size must match exactly or be kInferableDimSize
+ // indicating full dimension is sliced
+ if (size != kInferableDimSize && size != inputShape[i]) {
+ sizeMatchesInput = false;
+ break;
+ }
+ }
+ }
+
+ if (sizeMatchesInput)
+ return getInput1();
+ }
+
+ // The following checks require the input to be a constant
if (!adaptor.getInput1())
return {};
@@ -1786,10 +1833,6 @@ OpFoldResult SliceOp::fold(FoldAdaptor adaptor) {
if (inputTy.hasStaticShape() && outputTy.hasStaticShape() &&
outputTy.getNumElements() == 1) {
- DenseElementsAttr startElems;
- if (!matchPattern(getStart(), m_Constant(&startElems)))
- return {};
-
llvm::SmallVector<uint64_t> indices =
llvm::to_vector(startElems.getValues<uint64_t>());
auto value = operand.getValues<Attribute>()[indices];
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 1ade9793048de..52098413f18d9 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -784,6 +784,39 @@ func.func @slice_nofold(%arg0: tensor<?x4xf32>) -> tensor<?x4xf32> {
%3 = tosa.slice %arg0, %0, %1 : (tensor<?x4xf32>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<?x4xf32>
return %3 : tensor<?x4xf32>
}
+// -----
+
+// CHECK-LABEL: @slice_fold_dynamic
+func.func @slice_fold_dynamic(%arg0: tensor<?x4xf32>) -> tensor<?x4xf32> {
+ %0 = tosa.const_shape {values = dense<[0, 0]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ %1 = tosa.const_shape {values = dense<[-1, 4]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ // CHECK: return %arg0
+ %3 = tosa.slice %arg0, %0, %1 : (tensor<?x4xf32>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<?x4xf32>
+ return %3 : tensor<?x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @slice_fold_static_dynamic
+func.func @slice_fold_static_dynamic(%arg0: tensor<?x4xf32>) -> tensor<?x4xf32> {
+ %0 = tosa.const_shape {values = dense<[0, 0]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ %1 = tosa.const_shape {values = dense<[-1, -1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ // CHECK: return %arg0
+ %3 = tosa.slice %arg0, %0, %1 : (tensor<?x4xf32>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<?x4xf32>
+ return %3 : tensor<?x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @slice_nofold_static
+func.func @slice_nofold_static(%arg0: tensor<3x4xf32>) -> tensor<3x2xf32> {
+ %0 = tosa.const_shape {values = dense<[0, 0]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ %1 = tosa.const_shape {values = dense<[3, 2]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ // CHECK: tosa.slice
+ %3 = tosa.slice %arg0, %0, %1 : (tensor<3x4xf32>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<3x2xf32>
+ return %3 : tensor<3x2xf32>
+}
+
// -----
>From 3ed30d66c30f4b300c564dd8dffd978b49ef53fb Mon Sep 17 00:00:00 2001
From: Sayan Saha <sayans at mathworks.com>
Date: Wed, 4 Mar 2026 09:17:49 -0500
Subject: [PATCH 2/2] Fix formatting.
---
mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp | 14 +++++++-------
1 file changed, 7 insertions(+), 7 deletions(-)
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 571bd684af4c2..0dbe4e43a43a0 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -854,9 +854,9 @@ struct SliceDynamicSizeCanonicalization
llvm::to_vector(sizeElems.getValues<int64_t>());
bool replaceSliceSize{false};
- // if size op has kInferableDimSize indicating dynamic shape but corresponding dim on the
- // output is statically known, update size to match with known output dim
- // shape
+ // if size op has kInferableDimSize indicating dynamic shape but
+ // corresponding dim on the output is statically known, update size to match
+ // with known output dim shape
for (const auto &[index, size] : llvm::enumerate(sliceSizes)) {
if (size == kInferableDimSize && !resultType.isDynamicDim(index)) {
sliceSizes[index] = resultType.getDimSize(index);
@@ -1797,15 +1797,15 @@ OpFoldResult SliceOp::fold(FoldAdaptor adaptor) {
int64_t size = sizeVal.getSExtValue();
if (inputTy.isDynamicDim(i)) {
- // For dynamic dimensions, check for kInferableDimSize indicating full dimension is
- // sliced
+ // For dynamic dimensions, check for kInferableDimSize indicating full
+ // dimension is sliced
if (size != kInferableDimSize) {
sizeMatchesInput = false;
break;
}
} else {
- // For static dimensions, check that size must match exactly or be kInferableDimSize
- // indicating full dimension is sliced
+ // For static dimensions, check that size must match exactly or be
+ // kInferableDimSize indicating full dimension is sliced
if (size != kInferableDimSize && size != inputShape[i]) {
sizeMatchesInput = false;
break;
More information about the Mlir-commits mailing list