[Mlir-commits] [mlir] [tosa] : Enhance tosa.slice folding for dynamic dims. (PR #184045)

Sayan Saha llvmlistbot at llvm.org
Sun Mar 1 15:20:12 PST 2026


https://github.com/sahas3 updated https://github.com/llvm/llvm-project/pull/184045

>From ec0ab59622c6c42bbaa01c6734b82d8b2b5e8d22 Mon Sep 17 00:00:00 2001
From: Sayan Saha <sayans at mathworks.com>
Date: Sat, 28 Feb 2026 19:16:17 -0500
Subject: [PATCH 1/3] [tosa] : Enhance tosa.slice folding for dynamic dims.

---
 .../Dialect/Tosa/IR/TosaCanonicalizations.cpp | 54 +++++++++++++++++--
 mlir/test/Dialect/Tosa/canonicalize.mlir      | 33 ++++++++++++
 2 files changed, 82 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 7a1dbcd3e84c7..2a6976cb8b818 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -1771,6 +1771,54 @@ OpFoldResult SliceOp::fold(FoldAdaptor adaptor) {
   if (inputTy == outputTy && inputTy.hasStaticShape())
     return getInput1();
 
+  // Check if this is a no-op slice (all starts are 0 and every size covers
+  // the full input dimension). A non-constant start still bails out of every
+  // fold below, since the single-element fold reuses `startElems`.
+  DenseElementsAttr startElems;
+  if (!matchPattern(getStart(), m_Constant(&startElems)))
+    return {};
+
+  // Check if all start values are zero
+  bool startIsZeros =
+      llvm::all_of(startElems.getValues<APInt>(),
+                   [](const APInt &val) { return val.isZero(); });
+
+  // Only forward the input when it already has the exact result type (a
+  // tensor<3x4xf32> input must not replace a tensor<?x4xf32> result). A
+  // non-constant or mis-ranked size merely skips this fold.
+  DenseElementsAttr sizeElems;
+  if (startIsZeros && inputTy == outputTy &&
+      matchPattern(getSize(), m_Constant(&sizeElems)) &&
+      sizeElems.getNumElements() == inputTy.getRank()) {
+    auto inputShape = inputTy.getShape();
+    auto sizeValues = sizeElems.getValues<APInt>();
+
+    // Every dimension must be sliced in full: -1 always means the whole
+    // dimension, while a static dimension also accepts its exact extent.
+    bool sizeMatchesInput = true;
+    for (const auto &[i, sizeVal] : llvm::enumerate(sizeValues)) {
+      int64_t size = sizeVal.getSExtValue();
+
+      if (inputTy.isDynamicDim(i)) {
+        // For dynamic dimensions, check for -1 indicating full dimension is sliced
+        if (size != -1) {
+          sizeMatchesInput = false;
+          break;
+        }
+      } else {
+        // For static dimensions, check that size must match exactly or be -1 indicating full dimension is sliced
+        if (size != -1 && size != inputShape[i]) {
+          sizeMatchesInput = false;
+          break;
+        }
+      }
+    }
+
+    if (sizeMatchesInput)
+      return getInput1();
+  }
+
+  // The following checks require the input to be a constant
   if (!adaptor.getInput1())
     return {};
 
@@ -1785,11 +1833,7 @@ OpFoldResult SliceOp::fold(FoldAdaptor adaptor) {
   }
 
   if (inputTy.hasStaticShape() && outputTy.hasStaticShape() &&
-      outputTy.getNumElements() == 1) {
-    DenseElementsAttr startElems;
-    if (!matchPattern(getStart(), m_Constant(&startElems)))
-      return {};
-
+      outputTy.getNumElements() == 1) {    
     llvm::SmallVector<uint64_t> indices =
         llvm::to_vector(startElems.getValues<uint64_t>());
     auto value = operand.getValues<Attribute>()[indices];
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 1ade9793048de..52098413f18d9 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -784,6 +784,39 @@ func.func @slice_nofold(%arg0: tensor<?x4xf32>) -> tensor<?x4xf32> {
   %3 = tosa.slice %arg0, %0, %1 : (tensor<?x4xf32>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<?x4xf32>
   return %3 : tensor<?x4xf32>
 }
+// -----
+
+// CHECK-LABEL: @slice_fold_dynamic
+func.func @slice_fold_dynamic(%arg0: tensor<?x4xf32>) -> tensor<?x4xf32> {
+  %0 = tosa.const_shape {values = dense<[0, 0]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %1 = tosa.const_shape {values = dense<[-1, 4]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  // CHECK: return %arg0
+  %3 = tosa.slice %arg0, %0, %1 : (tensor<?x4xf32>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<?x4xf32>
+  return %3 : tensor<?x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @slice_fold_static_dynamic
+func.func @slice_fold_static_dynamic(%arg0: tensor<?x4xf32>) -> tensor<?x4xf32> {
+  %0 = tosa.const_shape {values = dense<[0, 0]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %1 = tosa.const_shape {values = dense<[-1, -1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  // CHECK: return %arg0
+  %3 = tosa.slice %arg0, %0, %1 : (tensor<?x4xf32>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<?x4xf32>
+  return %3 : tensor<?x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @slice_nofold_static
+func.func @slice_nofold_static(%arg0: tensor<3x4xf32>) -> tensor<3x2xf32> {
+  %0 = tosa.const_shape {values = dense<[0, 0]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %1 = tosa.const_shape {values = dense<[3, 2]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  // CHECK: tosa.slice
+  %3 = tosa.slice %arg0, %0, %1 : (tensor<3x4xf32>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<3x2xf32>
+  return %3 : tensor<3x2xf32>
+}
+
 
 // -----
 

>From 413d4769c170c062ae2ed362829c9dfde0f97707 Mon Sep 17 00:00:00 2001
From: Sayan Saha <sayans at mathworks.com>
Date: Sun, 1 Mar 2026 17:38:17 -0500
Subject: [PATCH 2/3] Fix formatting.

---
 mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 2a6976cb8b818..41b37052aef1b 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -1800,13 +1800,15 @@ OpFoldResult SliceOp::fold(FoldAdaptor adaptor) {
       int64_t size = sizeVal.getSExtValue();
 
       if (inputTy.isDynamicDim(i)) {
-        // For dynamic dimensions, check for -1 indicating full dimension is sliced
+        // For dynamic dimensions, only -1 (the whole dimension) is accepted,
+        // since an explicit size cannot be proven to equal an unknown extent.
         if (size != -1) {
           sizeMatchesInput = false;
           break;
         }
       } else {
-        // For static dimensions, check that size must match exactly or be -1 indicating full dimension is sliced
+        // For static dimensions, the size must equal the dimension's extent
+        // or be -1, which selects the whole dimension.
         if (size != -1 && size != inputShape[i]) {
           sizeMatchesInput = false;
           break;
@@ -1833,7 +1835,7 @@ OpFoldResult SliceOp::fold(FoldAdaptor adaptor) {
   }
 
   if (inputTy.hasStaticShape() && outputTy.hasStaticShape() &&
-      outputTy.getNumElements() == 1) {    
+      outputTy.getNumElements() == 1) {
     llvm::SmallVector<uint64_t> indices =
         llvm::to_vector(startElems.getValues<uint64_t>());
     auto value = operand.getValues<Attribute>()[indices];

>From dd64b94d2053a0ea5d9a180d7ca82299b6899712 Mon Sep 17 00:00:00 2001
From: Sayan Saha <sayans at mathworks.com>
Date: Sun, 1 Mar 2026 18:19:55 -0500
Subject: [PATCH 3/3] Add a nofold test

---
 mlir/test/Dialect/Tosa/canonicalize.mlir | 11 +++++++++++
 1 file changed, 11 insertions(+)

diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 52098413f18d9..bb3af8338043a 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -797,6 +797,17 @@ func.func @slice_fold_dynamic(%arg0: tensor<?x4xf32>) -> tensor<?x4xf32> {
 
 // -----
 
+// CHECK-LABEL: @slice_nofold_dynamic
+func.func @slice_nofold_dynamic(%arg0: tensor<?x4xf32>) -> tensor<?x4xf32> {
+  %0 = tosa.const_shape {values = dense<[1, 0]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %1 = tosa.const_shape {values = dense<[-1, 4]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  // CHECK: tosa.slice
+  %3 = tosa.slice %arg0, %0, %1 : (tensor<?x4xf32>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<?x4xf32>
+  return %3 : tensor<?x4xf32>
+}
+
+// -----
+
 // CHECK-LABEL: @slice_fold_static_dynamic
 func.func @slice_fold_static_dynamic(%arg0: tensor<?x4xf32>) -> tensor<?x4xf32> {
   %0 = tosa.const_shape {values = dense<[0, 0]> : tensor<2xindex>} : () -> !tosa.shape<2>



More information about the Mlir-commits mailing list