[Mlir-commits] [mlir] [mlir][tosa]: Add SLICE_SHAPE folder (PR #186997)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Mar 17 04:19:31 PDT 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-tosa

Author: Udaya Ranga (udaya-ranga)

<details>
<summary>Changes</summary>

Add a compile-time folder for tosa.SLICE_SHAPE.

---
Full diff: https://github.com/llvm/llvm-project/pull/186997.diff


3 Files Affected:

- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td (+2) 
- (modified) mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp (+60) 
- (modified) mlir/test/Dialect/Tosa/constant_folding.mlir (+12) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
index f2c48b7684c26..be4d31a4372c8 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
@@ -346,6 +346,8 @@ def Tosa_SliceShapeOp : Tosa_ShapeOp<"slice_shape", [Pure]> {
   let results = (outs Tosa_Shape:$output);
 
   let hasVerifier = 1;
+
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index b5e067f76a979..ef225dba9dc3e 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -1024,6 +1024,16 @@ static DenseElementsAttr unaryFolder(DenseElementsAttr val, ShapedType returnTy,
   return {};
 }
 
+static FailureOr<int64_t> getSingleI64From1ElementTensor(Value v) {
+  DenseIntElementsAttr dense{};
+  if (!matchPattern(v, m_Constant(&dense)))
+    return failure();
+
+  assert(dense.isSplat());
+  APInt a = dense.getSplatValue<APInt>();
+  return a.getSExtValue();
+}
+
 struct AddFoldAdaptor {
   static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
                                const bool isUnsigned) {
@@ -2093,6 +2103,52 @@ OpFoldResult concatShapeFold(tosa::ConcatShapeOp *op) {
   return DenseElementsAttr::get(rankedTy, concatDims);
 }
 
+OpFoldResult sliceShapeFold(tosa::SliceShapeOp *op) {
+  auto const input1 = op->getInput();
+  auto const input2 = op->getStart();
+  auto const input3 = op->getSize();
+
+  auto input1ConstShape = dyn_cast<tosa::ConstShapeOp>(input1.getDefiningOp());
+
+  if (!input1ConstShape)
+    return {};
+
+  auto const input1Attr = cast<DenseElementsAttr>(input1ConstShape.getValues());
+  if (!input1Attr)
+    return {};
+
+  auto const input1Vals = input1Attr.getValues<APInt>();
+  auto const totalInput1 = input1Vals.size();
+
+  auto const start = getSingleI64From1ElementTensor(input2);
+  auto const size = getSingleI64From1ElementTensor(input3);
+
+  if (failed(start) || failed(size))
+    return {};
+
+  auto const startV = static_cast<int32_t>(start.value());
+  auto const sizeV = static_cast<int32_t>(size.value());
+
+  if ((sizeV <= 0) || (startV < 0) ||
+      (static_cast<size_t>(startV + sizeV) > totalInput1))
+    return {};
+
+  SmallVector<APInt> sliceOfInput;
+  sliceOfInput.reserve(totalInput1);
+
+  for (auto i = startV; i < (startV + sizeV); i++) {
+    sliceOfInput.push_back(input1Vals[i]);
+  }
+
+  auto *ctx = op->getContext();
+  assert(ctx != nullptr && "ctx is nullptr");
+
+  auto const rankedTy = RankedTensorType::get(
+      {static_cast<int64_t>(sliceOfInput.size())}, IndexType::get(ctx));
+
+  return DenseElementsAttr::get(rankedTy, sliceOfInput);
+}
+
 OpFoldResult tosa::AddShapeOp::fold(FoldAdaptor adaptor) {
   return binaryFold<AddShapeOp, AddFoldAdaptor>(this);
 }
@@ -2140,3 +2196,7 @@ OpFoldResult tosa::Log2FloorShapeOp::fold(FoldAdaptor adaptor) {
 OpFoldResult tosa::ConcatShapeOp::fold(FoldAdaptor adaptor) {
   return concatShapeFold(this);
 }
+
+OpFoldResult tosa::SliceShapeOp::fold(FoldAdaptor adaptor) {
+  return sliceShapeFold(this);
+}
diff --git a/mlir/test/Dialect/Tosa/constant_folding.mlir b/mlir/test/Dialect/Tosa/constant_folding.mlir
index dd35f777bebb2..9eeb5d13317af 100644
--- a/mlir/test/Dialect/Tosa/constant_folding.mlir
+++ b/mlir/test/Dialect/Tosa/constant_folding.mlir
@@ -1469,4 +1469,16 @@ func.func @test_concat_shape_total_rank9_shapes() -> !tosa.shape<9> {
 
   return %abc : !tosa.shape<9>
 }
+
+// -----
+
+// CHECK-LABEL: @test_slice_shape
+// CHECK: tosa.const_shape  {values = dense<[3, 4, 5, 6]> : tensor<4xindex>} : () -> !tosa.shape<4>
+func.func @test_slice_shape() -> !tosa.shape<4> {
+  %a = tosa.const_shape {values = dense<[1, 2, 3, 4, 5, 6]> : tensor<6xindex>} : () -> !tosa.shape<6>
+  %b = "tosa.const"() {values = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>
+  %c = "tosa.const"() {values = dense<4> : tensor<1xi32>} : () -> tensor<1xi32>
+  %d = tosa.slice_shape %a, %b, %c : (!tosa.shape<6>, tensor<1xi32>, tensor<1xi32>) -> !tosa.shape<4>
+  return %d : !tosa.shape<4>
+}
 // -----

``````````

</details>


https://github.com/llvm/llvm-project/pull/186997


More information about the Mlir-commits mailing list