[Mlir-commits] [mlir] [tosa]: canonicalize dynamic size of tosa.slice to static output shape (PR #135429)

Sayan Saha llvmlistbot at llvm.org
Fri Apr 11 12:56:16 PDT 2025


https://github.com/sahas3 created https://github.com/llvm/llvm-project/pull/135429

Addresses https://github.com/llvm/llvm-project/issues/135389

>From 43099d90c2d27a1642c2cc23184dab934ac895d6 Mon Sep 17 00:00:00 2001
From: Sayan Saha <sayans at mathworks.com>
Date: Fri, 11 Apr 2025 15:53:33 -0400
Subject: [PATCH] [tosa]: canonicalize dynamic size of tosa.slice to static
 output shape

---
 .../Dialect/Tosa/IR/TosaCanonicalizations.cpp | 54 ++++++++++++++++++-
 mlir/test/Dialect/Tosa/canonicalize.mlir      | 15 ++++++
 2 files changed, 68 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index c4ef7d0bb9ff5..67d8baf32539f 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -731,9 +731,62 @@ struct ConcatSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
   }
 };
 
+// Rewrites a tosa.slice whose constant `size` operand holds -1 (dynamic)
+// for a dimension that is statically known in the result type, replacing
+// each such -1 with the result's static extent. For example, with result
+// type tensor<2x60x58x?xf32>, size [-1, 60, 58, -1] becomes [2, 60, 58, -1].
+struct SliceDynamicSizeCanonicalization
+    : public OpRewritePattern<tosa::SliceOp> {
+  using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
+                                PatternRewriter &rewriter) const override {
+    ShapedType resultType = cast<ShapedType>(sliceOp.getType());
+    // isDynamicDim()/getDimSize() below require a ranked result type.
+    if (!resultType.hasRank())
+      return rewriter.notifyMatchFailure(sliceOp, "result must be ranked");
+
+    ElementsAttr sizeElems;
+    if (!matchPattern(sliceOp.getSize(), m_Constant(&sizeElems)))
+      return rewriter.notifyMatchFailure(
+          sliceOp, "size of slice must be a static ranked shape");
+
+    llvm::SmallVector<int64_t> sliceSizes =
+        llvm::to_vector(sizeElems.getValues<int64_t>());
+    if (static_cast<int64_t>(sliceSizes.size()) != resultType.getRank())
+      return rewriter.notifyMatchFailure(
+          sliceOp, "size length must match the result rank");
+
+    // -1 in `size` means "dynamic"; where the result type already fixes that
+    // dimension to a static extent, write the extent into `size` instead.
+    bool replaceSliceSize = false;
+    for (const auto &[index, size] : llvm::enumerate(sliceSizes)) {
+      if (size == -1 && !resultType.isDynamicDim(index)) {
+        sliceSizes[index] = resultType.getDimSize(index);
+        replaceSliceSize = true;
+      }
+    }
+
+    if (!replaceSliceSize)
+      return rewriter.notifyMatchFailure(
+          sliceOp, "no dimension of size of slice is dynamic that resolves "
+                   "to static output shape");
+
+    Value newSize = getTosaConstShape(rewriter, sliceOp.getLoc(), sliceSizes);
+    // replaceOpWithNewOp erases sliceOp, so it must not be used afterwards.
+    // The old size const_shape is deliberately left alone: it may have other
+    // users, and if it is now dead the greedy driver removes it.
+    rewriter.replaceOpWithNewOp<tosa::SliceOp>(
+        sliceOp, sliceOp.getType(), sliceOp.getInput1(), sliceOp.getStart(),
+        newSize);
+    return success();
+  }
+};
+
 void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
-  results.add<ConcatSliceOptimization>(context);
+  results.add<ConcatSliceOptimization, SliceDynamicSizeCanonicalization>(
+      context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index b366b4f1e4fd4..a754a46be603f 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -1212,3 +1212,32 @@ func.func @do_not_fold_intdiv_division_by_0() -> tensor<1x24x2xi32> {
  %16 = tosa.intdiv %4, %1 : (tensor<1x24x2xi32>, tensor<1x24x2xi32>) -> tensor<1x24x2xi32>
   return %16 : tensor<1x24x2xi32>
 }
+
+// -----
+
+// A -1 in size whose result dim is static (dim 0 = 2) is replaced by that extent.
+// CHECK-LABEL:   func.func @slice_dynamic_size_static_output_canonicalize(
+// CHECK-SAME:                     %[[ARG0:.*]]: tensor<2x60x59x?xf32>) -> tensor<2x60x58x?xf32> {
+// CHECK:           %[[START:.*]] = tosa.const_shape  {values = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>
+// CHECK:           %[[SIZE:.*]] = tosa.const_shape  {values = dense<[2, 60, 58, -1]> : tensor<4xindex>} : () -> !tosa.shape<4>
+// CHECK:           %[[SLICE:.*]] = tosa.slice %[[ARG0]], %[[START]], %[[SIZE]] : (tensor<2x60x59x?xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<2x60x58x?xf32>
+// CHECK:           return %[[SLICE]]
+func.func @slice_dynamic_size_static_output_canonicalize(%arg0: tensor<2x60x59x?xf32>) -> tensor<2x60x58x?xf32> {
+  %0 = tosa.const_shape  {values = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %1 = tosa.const_shape  {values = dense<[-1, 60, 58, -1]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %2 = tosa.slice %arg0, %0, %1 : (tensor<2x60x59x?xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<2x60x58x?xf32>
+  return %2 : tensor<2x60x58x?xf32>
+}
+
+// -----
+
+// A -1 in size whose result dim is dynamic is left untouched.
+// CHECK-LABEL:   func.func @slice_dynamic_size_dynamic_output_no_canonicalize(
+// CHECK:           %[[SIZE:.*]] = tosa.const_shape  {values = dense<[2, -1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+// CHECK:           tosa.slice %{{.*}}, %{{.*}}, %[[SIZE]]
+func.func @slice_dynamic_size_dynamic_output_no_canonicalize(%arg0: tensor<4x?xf32>) -> tensor<2x?xf32> {
+  %0 = tosa.const_shape  {values = dense<0> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %1 = tosa.const_shape  {values = dense<[2, -1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %2 = tosa.slice %arg0, %0, %1 : (tensor<4x?xf32>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<2x?xf32>
+  return %2 : tensor<2x?xf32>
+}



More information about the Mlir-commits mailing list