[Mlir-commits] [mlir] [tosa] : Enhance tosa.slice folding for dynamic dims. (PR #184615)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Mar 4 06:16:15 PST 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-tosa

Author: Sayan Saha (sahas3)

<details>
<summary>Changes</summary>

Source IR:
```
func.func @main(%arg0: tensor<?x112x64x112xf32>) -> tensor<?x113x65x112xf32> {
    %0 = tosa.const_shape  {values = dense<[0, 0, 1, 1, 1, 1, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
    %1 = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
    %2 = tosa.pad %arg0, %0, %1 : (tensor<?x112x64x112xf32>, !tosa.shape<8>, tensor<1xf32>) -> tensor<?x114x66x112xf32>
    %3 = tosa.const_shape  {values = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>
    %4 = tosa.const_shape  {values = dense<[-1, 113, 65, 112]> : tensor<4xindex>} : () -> !tosa.shape<4>
    %5 = tosa.slice %2, %3, %4 : (tensor<?x114x66x112xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<?x113x65x112xf32>
    return %5 : tensor<?x113x65x112xf32>
  }
```

When canonicalized, this IR produces:

```
$> mlir-opt --canonicalize

func.func @main(%arg0: tensor<?x112x64x112xf32>) -> tensor<?x113x65x112xf32> {
    %0 = tosa.const_shape  {values = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>
    %1 = tosa.const_shape  {values = dense<[-1, 113, 65, 112]> : tensor<4xindex>} : () -> !tosa.shape<4>
    %2 = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
    %3 = tosa.const_shape  {values = dense<[0, 0, 1, 0, 1, 0, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
    %4 = tosa.pad %arg0, %3, %2 : (tensor<?x112x64x112xf32>, !tosa.shape<8>, tensor<1xf32>) -> tensor<?x113x65x112xf32>
    %5 = tosa.slice %4, %0, %1 : (tensor<?x113x65x112xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<?x113x65x112xf32>
    return %5 : tensor<?x113x65x112xf32>
  }
```

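because of the `PadSliceOptimization` pattern. After that optimization, the remaining `tosa.slice` is effectively a no-op: all starts are zero and every size either equals the input dimension or is `-1` (take the whole dimension). This change enhances the `tosa.slice` folder to fold such no-op slices away. With this change, canonicalization produces: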

```
func.func @main(%arg0: tensor<?x112x64x112xf32>) -> tensor<?x113x65x112xf32> {
    %0 = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
    %1 = tosa.const_shape  {values = dense<[0, 0, 1, 0, 1, 0, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
    %2 = tosa.pad %arg0, %1, %0 : (tensor<?x112x64x112xf32>, !tosa.shape<8>, tensor<1xf32>) -> tensor<?x113x65x112xf32>
    return %2 : tensor<?x113x65x112xf32>
  }
```

---
Full diff: https://github.com/llvm/llvm-project/pull/184615.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp (+50-7) 
- (modified) mlir/test/Dialect/Tosa/canonicalize.mlir (+33) 


``````````diff
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 7a1dbcd3e84c7..571bd684af4c2 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -754,7 +754,7 @@ struct PadSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
     if (llvm::any_of(llvm::seq<int64_t>(0, rank), [&](int64_t i) {
           const bool isDimDynamic = inputTy.isDynamicDim(i);
           const bool isDimSliced =
-              (sliceStarts[i] != 0) || (sliceSizes[i] != -1);
+              (sliceStarts[i] != 0) || (sliceSizes[i] != kInferableDimSize);
 
           return isDimDynamic && isDimSliced;
         })) {
@@ -854,11 +854,11 @@ struct SliceDynamicSizeCanonicalization
         llvm::to_vector(sizeElems.getValues<int64_t>());
 
     bool replaceSliceSize{false};
-    // if size op has -1 indicating dynamic shape but corresponding dim on the
+    // if size op has kInferableDimSize indicating dynamic shape but corresponding dim on the
     // output is statically known, update size to match with known output dim
     // shape
     for (const auto &[index, size] : llvm::enumerate(sliceSizes)) {
-      if (size == -1 && !resultType.isDynamicDim(index)) {
+      if (size == kInferableDimSize && !resultType.isDynamicDim(index)) {
         sliceSizes[index] = resultType.getDimSize(index);
         replaceSliceSize = true;
       }
@@ -1771,6 +1771,53 @@ OpFoldResult SliceOp::fold(FoldAdaptor adaptor) {
   if (inputTy == outputTy && inputTy.hasStaticShape())
     return getInput1();
 
+  // Check if this is a no-op slice (starts at 0 and size matches input)
+
+  DenseElementsAttr startElems;
+  if (!matchPattern(getStart(), m_Constant(&startElems)))
+    return {};
+
+  // Check if all start values are zero
+  bool startIsZeros =
+      llvm::all_of(startElems.getValues<APInt>(),
+                   [](const APInt &val) { return val.isZero(); });
+
+  if (startIsZeros) {
+
+    // Check if size matches input shape
+    DenseElementsAttr sizeElems;
+    if (!matchPattern(getSize(), m_Constant(&sizeElems)))
+      return {};
+
+    auto inputShape = inputTy.getShape();
+    auto sizeValues = sizeElems.getValues<APInt>();
+
+    bool sizeMatchesInput = true;
+    for (const auto &[i, sizeVal] : llvm::enumerate(sizeValues)) {
+      int64_t size = sizeVal.getSExtValue();
+
+      if (inputTy.isDynamicDim(i)) {
+        // For dynamic dimensions, check for kInferableDimSize indicating full dimension is
+        // sliced
+        if (size != kInferableDimSize) {
+          sizeMatchesInput = false;
+          break;
+        }
+      } else {
+        // For static dimensions, check that size must match exactly or be kInferableDimSize
+        // indicating full dimension is sliced
+        if (size != kInferableDimSize && size != inputShape[i]) {
+          sizeMatchesInput = false;
+          break;
+        }
+      }
+    }
+
+    if (sizeMatchesInput)
+      return getInput1();
+  }
+
+  // The following checks require the input to be a constant
   if (!adaptor.getInput1())
     return {};
 
@@ -1786,10 +1833,6 @@ OpFoldResult SliceOp::fold(FoldAdaptor adaptor) {
 
   if (inputTy.hasStaticShape() && outputTy.hasStaticShape() &&
       outputTy.getNumElements() == 1) {
-    DenseElementsAttr startElems;
-    if (!matchPattern(getStart(), m_Constant(&startElems)))
-      return {};
-
     llvm::SmallVector<uint64_t> indices =
         llvm::to_vector(startElems.getValues<uint64_t>());
     auto value = operand.getValues<Attribute>()[indices];
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 1ade9793048de..52098413f18d9 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -784,6 +784,39 @@ func.func @slice_nofold(%arg0: tensor<?x4xf32>) -> tensor<?x4xf32> {
   %3 = tosa.slice %arg0, %0, %1 : (tensor<?x4xf32>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<?x4xf32>
   return %3 : tensor<?x4xf32>
 }
+// -----
+
+// CHECK-LABEL: @slice_fold_dynamic
+func.func @slice_fold_dynamic(%arg0: tensor<?x4xf32>) -> tensor<?x4xf32> {
+  %0 = tosa.const_shape {values = dense<[0, 0]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %1 = tosa.const_shape {values = dense<[-1, 4]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  // CHECK: return %arg0
+  %3 = tosa.slice %arg0, %0, %1 : (tensor<?x4xf32>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<?x4xf32>
+  return %3 : tensor<?x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @slice_fold_static_dynamic
+func.func @slice_fold_static_dynamic(%arg0: tensor<?x4xf32>) -> tensor<?x4xf32> {
+  %0 = tosa.const_shape {values = dense<[0, 0]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %1 = tosa.const_shape {values = dense<[-1, -1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  // CHECK: return %arg0
+  %3 = tosa.slice %arg0, %0, %1 : (tensor<?x4xf32>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<?x4xf32>
+  return %3 : tensor<?x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @slice_nofold_static
+func.func @slice_nofold_static(%arg0: tensor<3x4xf32>) -> tensor<3x2xf32> {
+  %0 = tosa.const_shape {values = dense<[0, 0]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %1 = tosa.const_shape {values = dense<[3, 2]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  // CHECK: tosa.slice
+  %3 = tosa.slice %arg0, %0, %1 : (tensor<3x4xf32>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<3x2xf32>
+  return %3 : tensor<3x2xf32>
+}
+
 
 // -----
 

``````````

</details>


https://github.com/llvm/llvm-project/pull/184615


More information about the Mlir-commits mailing list