[Mlir-commits] [mlir] [mlir][tosa] Fold 'small' constant 1D slice operations (PR #128193)

Tai Ly llvmlistbot at llvm.org
Thu Mar 6 19:02:08 PST 2025


https://github.com/Tai78641 updated https://github.com/llvm/llvm-project/pull/128193

>From 7fd38bf8ad5cf9cc47c3453a80042af913ce7129 Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Tue, 29 Oct 2024 14:58:20 +0000
Subject: [PATCH] [mlir][tosa] Fold 'small' constant 1D slice operations

This commit extends the slice folder to fold 1D slice operations whose
inputs are all constant and whose output has at most 6 elements. Keeping
the folder restricted to small outputs avoids a large folder runtime or
increased memory requirements.

This folder is useful in the context of legalizing dynamic models where
the input shapes are resolved to static shapes directly before
legalization. In this context, constant shape operations act on tensors
with at most 6 elements (tosa_level_8k MAX_RANK).

Change-Id: I1e59e5919f8c2936e98788c5a9b44a691940b28a
Signed-off-by: Luke Hutton <luke.hutton at arm.com>
---
 .../Dialect/Tosa/IR/TosaCanonicalizations.cpp | 36 ++++++++++---
 mlir/test/Dialect/Tosa/constant_folding.mlir  | 51 +++++++++++++++++++
 2 files changed, 80 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 3e99c1f717d09..ea37a76360ed2 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -1054,18 +1054,40 @@ OpFoldResult SliceOp::fold(FoldAdaptor adaptor) {
     return SplatElementsAttr::get(outputTy, operand.getSplatValue<Attribute>());
   }
 
-  if (inputTy.hasStaticShape() && outputTy.hasStaticShape() &&
-      outputTy.getNumElements() == 1) {
-    DenseElementsAttr startElems;
-    if (!matchPattern(getStart(), m_Constant(&startElems)))
-      return {};
+  if (!inputTy.hasStaticShape() || !outputTy.hasStaticShape())
+    return {};
+
+  DenseElementsAttr startElems;
+  if (!matchPattern(getStart(), m_Constant(&startElems)))
+    return {};
 
-    llvm::SmallVector<uint64_t> indices =
-        llvm::to_vector(startElems.getValues<uint64_t>());
+  auto indices = llvm::to_vector(startElems.getValues<uint64_t>());
+
+  if (outputTy.getNumElements() == 1) {
     auto value = operand.getValues<Attribute>()[indices];
     return SplatElementsAttr::get(outputTy, value);
   }
 
+  DenseElementsAttr size_elems;
+  if (!matchPattern(getSize(), m_Constant(&size_elems)))
+    return {};
+
+  const auto sizes = llvm::to_vector(size_elems.getValues<uint64_t>());
+
+  // Fold slice when all operands are constant and the output is 'small'
+  // A 'small' output is currently defined as 1D and <= 6 elements
+  // (tosa_level_8k MAX_RANK)
+  if (inputTy.getRank() == 1 && outputTy.getRank() == 1 &&
+      outputTy.getNumElements() <= 6 && indices.size() == 1 &&
+      sizes.size() == 1) {
+    const auto begin = operand.value_begin<Attribute>();
+    const uint64_t offset = indices[0];
+    const uint64_t size = sizes[0];
+    const SmallVector<Attribute> slicedValues(begin + offset,
+                                              begin + offset + size);
+    return DenseElementsAttr::get(outputTy, slicedValues);
+  }
+
   return {};
 }
 
diff --git a/mlir/test/Dialect/Tosa/constant_folding.mlir b/mlir/test/Dialect/Tosa/constant_folding.mlir
index 9b6ccdb54c107..3c1f2c5058b95 100644
--- a/mlir/test/Dialect/Tosa/constant_folding.mlir
+++ b/mlir/test/Dialect/Tosa/constant_folding.mlir
@@ -21,3 +21,54 @@ func.func @try_fold_equal_with_unranked_tensor(%arg0: tensor<4xi32>, %arg1: tens
   %0 = tosa.equal %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi1>
   return
 }
+
+// -----
+
+// CHECK-LABEL: test_1d_slice
+func.func @test_1d_slice() -> tensor<6xi32> {
+  // CHECK: %[[VAL_0:.+]] = "tosa.const"() <{values = dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi32>}> : () -> tensor<6xi32>
+  // CHECK: return %[[VAL_0]] : tensor<6xi32>
+  %0 = "tosa.const"() <{values = dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]> : tensor<10xi32>}> : () -> tensor<10xi32>
+  %1 = tosa.const_shape {values = dense<1> : tensor<1xindex>} : () -> !tosa.shape<1> // slice start
+  %2 = tosa.const_shape {values = dense<6> : tensor<1xindex>} : () -> !tosa.shape<1> // slice size
+  %3 = tosa.slice %0, %1, %2 : (tensor<10xi32>, !tosa.shape<1>, !tosa.shape<1>) -> tensor<6xi32>
+  return %3 : tensor<6xi32>
+}
+
+// -----
+
+// CHECK-LABEL: test_1d_slice_non_const_input
+func.func @test_1d_slice_non_const_input(%arg0 : tensor<10xi32>) -> tensor<6xi32> {
+  // Check that the slice is not folded when input1 (%arg0) is not a constant.
+  // CHECK: tosa.slice
+  %1 = tosa.const_shape {values = dense<1> : tensor<1xindex>} : () -> !tosa.shape<1>
+  %2 = tosa.const_shape {values = dense<6> : tensor<1xindex>} : () -> !tosa.shape<1>
+  %3 = tosa.slice %arg0, %1, %2 : (tensor<10xi32>, !tosa.shape<1>, !tosa.shape<1>) -> tensor<6xi32>
+  return %3 : tensor<6xi32>
+}
+
+// -----
+
+// CHECK-LABEL: test_1d_slice_rank_2_input
+func.func @test_1d_slice_rank_2_input(%arg0 : tensor<1x10xi32>) -> tensor<1x6xi32> {
+  // Check that the slice is not folded when the constant input1 (%0) has rank > 1.
+  // CHECK: tosa.slice
+  %0 = "tosa.const"() <{values = dense<[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]> : tensor<1x10xi32>}> : () -> tensor<1x10xi32>
+  %1 = tosa.const_shape {values = dense<[0, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %2 = tosa.const_shape {values = dense<[1, 6]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %3 = tosa.slice %0, %1, %2 : (tensor<1x10xi32>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x6xi32>
+  return %3 : tensor<1x6xi32>
+}
+
+// -----
+
+// CHECK-LABEL: test_1d_slice_more_than_6
+func.func @test_1d_slice_more_than_6() -> tensor<7xi32> {
+  // Check that the slice is not folded: the output has 7 elements, above the limit of 6.
+  // CHECK: tosa.slice
+  %0 = "tosa.const"() <{values = dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]> : tensor<10xi32>}> : () -> tensor<10xi32>
+  %1 = tosa.const_shape {values = dense<1> : tensor<1xindex>} : () -> !tosa.shape<1>
+  %2 = tosa.const_shape {values = dense<7> : tensor<1xindex>} : () -> !tosa.shape<1>
+  %3 = tosa.slice %0, %1, %2 : (tensor<10xi32>, !tosa.shape<1>, !tosa.shape<1>) -> tensor<7xi32>
+  return %3 : tensor<7xi32>
+}



More information about the Mlir-commits mailing list