[Mlir-commits] [mlir] [mlir][tosa]: Add SLICE_SHAPE folder (PR #186997)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Mar 17 04:19:31 PDT 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-tosa
Author: Udaya Ranga (udaya-ranga)
<details>
<summary>Changes</summary>
Add a compile-time folder for tosa.SLICE_SHAPE.
---
Full diff: https://github.com/llvm/llvm-project/pull/186997.diff
3 Files Affected:
- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td (+2)
- (modified) mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp (+60)
- (modified) mlir/test/Dialect/Tosa/constant_folding.mlir (+12)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
index f2c48b7684c26..be4d31a4372c8 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
@@ -346,6 +346,8 @@ def Tosa_SliceShapeOp : Tosa_ShapeOp<"slice_shape", [Pure]> {
let results = (outs Tosa_Shape:$output);
let hasVerifier = 1;
+
+ let hasFolder = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index b5e067f76a979..ef225dba9dc3e 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -1024,6 +1024,16 @@ static DenseElementsAttr unaryFolder(DenseElementsAttr val, ShapedType returnTy,
return {};
}
+/// Returns the single integer held by the constant one-element tensor `v`,
+/// sign-extended to int64_t. Fails (instead of asserting) if `v` is not
+/// produced by a constant, if the constant does not hold exactly one
+/// element, or if the value does not fit in 64 signed bits.
+static FailureOr<int64_t> getSingleI64From1ElementTensor(Value v) {
+  DenseIntElementsAttr dense;
+  if (!matchPattern(v, m_Constant(&dense)))
+    return failure();
+
+  // A runtime check rather than an assert: release builds must not read a
+  // "splat" value out of an empty or multi-element constant.
+  if (dense.getNumElements() != 1)
+    return failure();
+
+  APInt value = *dense.getValues<APInt>().begin();
+  // getSExtValue() asserts when the value needs more than 64 bits.
+  if (value.getSignificantBits() > 64)
+    return failure();
+  return value.getSExtValue();
+}
+
struct AddFoldAdaptor {
static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
const bool isUnsigned) {
@@ -2093,6 +2103,52 @@ OpFoldResult concatShapeFold(tosa::ConcatShapeOp *op) {
return DenseElementsAttr::get(rankedTy, concatDims);
}
+/// Folds tosa.slice_shape into a constant shape when the input shape and the
+/// `start`/`size` operands are all compile-time constants. The folded value is
+/// the sub-shape input[start, start + size) as a tensor<size x index>.
+/// Returns a null OpFoldResult (no fold) when any operand is non-constant or
+/// when the requested slice is empty or out of bounds.
+OpFoldResult sliceShapeFold(tosa::SliceShapeOp *op) {
+  // The templated getDefiningOp also copes with block arguments, for which
+  // the untemplated form returns null (and dyn_cast on null would assert).
+  auto inputConstShape = op->getInput().getDefiningOp<tosa::ConstShapeOp>();
+  if (!inputConstShape)
+    return {};
+
+  auto inputAttr = dyn_cast<DenseElementsAttr>(inputConstShape.getValues());
+  if (!inputAttr)
+    return {};
+
+  const FailureOr<int64_t> start =
+      getSingleI64From1ElementTensor(op->getStart());
+  const FailureOr<int64_t> size = getSingleI64From1ElementTensor(op->getSize());
+  if (failed(start) || failed(size))
+    return {};
+
+  // Stay in int64_t (no narrowing to int32_t) and bound-check without forming
+  // start + size, so large constants can neither be truncated nor overflow.
+  const int64_t totalInput = inputAttr.getNumElements();
+  const int64_t startV = *start;
+  const int64_t sizeV = *size;
+  if (sizeV <= 0 || startV < 0 || startV > totalInput - sizeV)
+    return {};
+
+  auto inputVals = inputAttr.getValues<APInt>();
+  SmallVector<APInt> sliceOfInput;
+  sliceOfInput.reserve(sizeV);
+  // Safe: the check above guarantees startV + sizeV <= totalInput.
+  for (int64_t i = startV, end = startV + sizeV; i < end; ++i)
+    sliceOfInput.push_back(inputVals[i]);
+
+  auto rankedTy =
+      RankedTensorType::get({sizeV}, IndexType::get(op->getContext()));
+  return DenseElementsAttr::get(rankedTy, sliceOfInput);
+}
+
OpFoldResult tosa::AddShapeOp::fold(FoldAdaptor adaptor) {
return binaryFold<AddShapeOp, AddFoldAdaptor>(this);
}
@@ -2140,3 +2196,7 @@ OpFoldResult tosa::Log2FloorShapeOp::fold(FoldAdaptor adaptor) {
OpFoldResult tosa::ConcatShapeOp::fold(FoldAdaptor adaptor) {
return concatShapeFold(this);
}
+
+/// Constant-folds tosa.slice_shape. The shared sliceShapeFold helper reads
+/// the op's operands directly, so the fold adaptor is not consulted here.
+OpFoldResult tosa::SliceShapeOp::fold(FoldAdaptor /*adaptor*/) {
+  return sliceShapeFold(this);
+}
diff --git a/mlir/test/Dialect/Tosa/constant_folding.mlir b/mlir/test/Dialect/Tosa/constant_folding.mlir
index dd35f777bebb2..9eeb5d13317af 100644
--- a/mlir/test/Dialect/Tosa/constant_folding.mlir
+++ b/mlir/test/Dialect/Tosa/constant_folding.mlir
@@ -1469,4 +1469,16 @@ func.func @test_concat_shape_total_rank9_shapes() -> !tosa.shape<9> {
return %abc : !tosa.shape<9>
}
+
+// -----
+
+// Slicing [1, 2, 3, 4, 5, 6] at start = 2 with size = 4 must fold to the
+// constant shape [3, 4, 5, 6], leaving no tosa.slice_shape behind.
+// CHECK-LABEL: @test_slice_shape
+// CHECK: tosa.const_shape {values = dense<[3, 4, 5, 6]> : tensor<4xindex>} : () -> !tosa.shape<4>
+// CHECK-NOT: tosa.slice_shape
+func.func @test_slice_shape() -> !tosa.shape<4> {
+  %a = tosa.const_shape {values = dense<[1, 2, 3, 4, 5, 6]> : tensor<6xindex>} : () -> !tosa.shape<6>
+  %b = "tosa.const"() {values = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>
+  %c = "tosa.const"() {values = dense<4> : tensor<1xi32>} : () -> tensor<1xi32>
+  %d = tosa.slice_shape %a, %b, %c : (!tosa.shape<6>, tensor<1xi32>, tensor<1xi32>) -> !tosa.shape<4>
+  return %d : !tosa.shape<4>
+}
// -----
``````````
</details>
https://github.com/llvm/llvm-project/pull/186997
More information about the Mlir-commits mailing list