[Mlir-commits] [mlir] 89d5551 - [mlir][tosa] Add constant folding for tosa.slice

Rob Suderman llvmlistbot at llvm.org
Wed Aug 24 15:34:52 PDT 2022


Author: Rob Suderman
Date: 2022-08-24T15:34:02-07:00
New Revision: 89d555134aa1bf6f80aea3043354166cdaaae016

URL: https://github.com/llvm/llvm-project/commit/89d555134aa1bf6f80aea3043354166cdaaae016
DIFF: https://github.com/llvm/llvm-project/commit/89d555134aa1bf6f80aea3043354166cdaaae016.diff

LOG: [mlir][tosa] Add constant folding for tosa.slice

If the input to a tosa.slice operation is a splat, we can just replace the
slice with another splat. If the result is a single element, replacing it
with a splat is universally useful.

Reviewed By: NatashaKnk

Differential Revision: https://reviews.llvm.org/D132499

Added: 
    

Modified: 
    mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
    mlir/test/Dialect/Tosa/constant-op-fold.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 571cd52eea89d..99477093b75ac 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -555,11 +555,30 @@ OpFoldResult SliceOp::fold(ArrayRef<Attribute> operands) {
   auto inputTy = getInput().getType().dyn_cast<RankedTensorType>();
   auto outputTy = getType().dyn_cast<RankedTensorType>();
 
-  if (!inputTy || !outputTy || inputTy != outputTy)
+  if (!inputTy || !outputTy)
     return {};
-  if (inputTy.hasStaticShape())
+
+  if (inputTy == outputTy && inputTy.hasStaticShape())
     return getInput();
 
+  if (!operands[0])
+    return {};
+
+  auto operand = operands[0].cast<ElementsAttr>();
+  if (operand.isSplat() && outputTy.hasStaticShape()) {
+    return SplatElementsAttr::get(outputTy, operand.getSplatValue<Attribute>());
+  }
+
+  if (inputTy.hasStaticShape() && outputTy.hasStaticShape() &&
+      outputTy.getNumElements() == 1) {
+    llvm::SmallVector<uint64_t> indices;
+    for (auto val : getStart()) {
+      indices.push_back(val.cast<IntegerAttr>().getInt());
+    }
+    auto value = operand.getValues<Attribute>()[indices];
+    return SplatElementsAttr::get(outputTy, value);
+  }
+
   return {};
 }
 

diff  --git a/mlir/test/Dialect/Tosa/constant-op-fold.mlir b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
index edf97144e0902..f392e4297be99 100644
--- a/mlir/test/Dialect/Tosa/constant-op-fold.mlir
+++ b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
@@ -161,3 +161,25 @@ func.func @fold_add_splat_f32() -> tensor<10xf32> {
   // CHECK: return %[[THREE]]
   return %add : tensor<10xf32>
 }
+
+// -----
+
+// CHECK-LABEL: @slice_splat
+func.func @slice_splat() -> tensor<1x1x1xi32> {
+  // CHECK: %[[SLICE:.+]] = "tosa.const"() {value = dense<42> : tensor<1x1x1xi32>}
+  %splat = "tosa.const"() {value = dense<42> : tensor<4x5x6xi32>} : () -> tensor<4x5x6xi32>
+  %slice = "tosa.slice"(%splat) { size = [1, 1, 1], start = [1, 2, 3] } : (tensor<4x5x6xi32>) -> tensor<1x1x1xi32>
+  // CHECK: return %[[SLICE]]
+  return %slice : tensor<1x1x1xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @slice_singleton
+func.func @slice_singleton() -> tensor<1x1xi32> {
+  %splat = "tosa.const"() {value = dense<[[0, 1, 2], [3, 4, 5], [6, 7 ,8]]> : tensor<3x3xi32>} : () -> tensor<3x3xi32>
+  // CHECK: %[[SLICE:.+]] = "tosa.const"() {value = dense<4> : tensor<1x1xi32>}
+  %slice = "tosa.slice"(%splat) { size = [1, 1], start = [1, 1] } : (tensor<3x3xi32>) -> tensor<1x1xi32>
+  // CHECK: return %[[SLICE]]
+  return %slice : tensor<1x1xi32>
+}


        


More information about the Mlir-commits mailing list