[Mlir-commits] [mlir] [mlir][tosa] Fold 'small' constant 1D slice operations (PR #128193)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Feb 21 08:30:25 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Tai Ly (Tai78641)
<details>
<summary>Changes</summary>
This commit extends the slice folder to fold 1D slice operations whose inputs are all constant and whose output has at most 6 elements. Restricting the folder to small inputs avoids long folding times and increased memory use.
This folder is useful when legalizing dynamic models whose input shapes are resolved to static shapes immediately before legalization. In that setting, constant shape operations act on tensors with at most 6 elements (the MAX_RANK of tosa_level_8k).
---
Full diff: https://github.com/llvm/llvm-project/pull/128193.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp (+30-7)
- (modified) mlir/test/Dialect/Tosa/constant_folding.mlir (+13)
``````````diff
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 9bfc2aae1d6a5..f5a21689c0af3 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -1115,18 +1115,41 @@ OpFoldResult SliceOp::fold(FoldAdaptor adaptor) {
return SplatElementsAttr::get(outputTy, operand.getSplatValue<Attribute>());
}
- if (inputTy.hasStaticShape() && outputTy.hasStaticShape() &&
- outputTy.getNumElements() == 1) {
- DenseElementsAttr startElems;
- if (!matchPattern(getStart(), m_Constant(&startElems)))
- return {};
+ if (!inputTy.hasStaticShape() || !outputTy.hasStaticShape())
+ return {};
+
+ DenseElementsAttr startElems;
+ if (!matchPattern(getStart(), m_Constant(&startElems)))
+ return {};
- llvm::SmallVector<uint64_t> indices =
- llvm::to_vector(startElems.getValues<uint64_t>());
+ llvm::SmallVector<uint64_t> indices =
+ llvm::to_vector(startElems.getValues<uint64_t>());
+
+ if (outputTy.getNumElements() == 1) {
auto value = operand.getValues<Attribute>()[indices];
return SplatElementsAttr::get(outputTy, value);
}
+ DenseElementsAttr size_elems;
+ if (!matchPattern(getSize(), m_Constant(&size_elems)))
+ return {};
+ const llvm::SmallVector<uint64_t> sizes =
+ llvm::to_vector(size_elems.getValues<uint64_t>());
+
+ // Fold slice when all operands are constant and the output is 'small'
+ // A 'small' output is currently defined as 1D and <= 6 elements
+ // (tosa_level_8k MAX_RANK)
+ if (inputTy.getRank() == 1 && outputTy.getRank() == 1 &&
+ outputTy.getNumElements() <= 6 && indices.size() == 1 &&
+ sizes.size() == 1) {
+ const auto begin = operand.value_begin<Attribute>();
+ const uint64_t offset = indices[0];
+ const uint64_t size = sizes[0];
+ const SmallVector<Attribute> slicedValues(begin + offset,
+ begin + offset + size);
+ return DenseElementsAttr::get(outputTy, slicedValues);
+ }
+
return {};
}
diff --git a/mlir/test/Dialect/Tosa/constant_folding.mlir b/mlir/test/Dialect/Tosa/constant_folding.mlir
index 3ff3121348fca..c82f522432295 100644
--- a/mlir/test/Dialect/Tosa/constant_folding.mlir
+++ b/mlir/test/Dialect/Tosa/constant_folding.mlir
@@ -21,3 +21,16 @@ func.func @try_fold_equal_with_unranked_tensor(%arg0: tensor<4xi32>, %arg1: tens
%0 = tosa.equal %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi1>
return
}
+
+// -----
+
+// CHECK-LABEL: test_1d_slice
+func.func @test_1d_slice() -> tensor<6xi32> {
+ // CHECK: %[[VAL_0:.+]] = "tosa.const"() <{value = dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi32>}> : () -> tensor<6xi32>
+ // CHECK: return %[[VAL_0]] : tensor<6xi32>
+ %0 = "tosa.const"() <{value = dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]> : tensor<10xi32>}> : () -> tensor<10xi32>
+ %1 = tosa.const_shape {value = dense<1> : tensor<1xindex>} : () -> !tosa.shape<1>
+ %2 = tosa.const_shape {value = dense<6> : tensor<1xindex>} : () -> !tosa.shape<1>
+ %3 = tosa.slice %0, %1, %2 : (tensor<10xi32>, !tosa.shape<1>, !tosa.shape<1>) -> tensor<6xi32>
+ return %3 : tensor<6xi32>
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/128193
More information about the Mlir-commits
mailing list