[Mlir-commits] [mlir] 3e49a3e - [mlir][tosa] Added tosa.reverse folder
Rob Suderman
llvmlistbot at llvm.org
Mon Sep 12 17:30:22 PDT 2022
Author: Rob Suderman
Date: 2022-09-12T17:15:17-07:00
New Revision: 3e49a3e89dbd44bfeb7611d76c23e137e1676abe
URL: https://github.com/llvm/llvm-project/commit/3e49a3e89dbd44bfeb7611d76c23e137e1676abe
DIFF: https://github.com/llvm/llvm-project/commit/3e49a3e89dbd44bfeb7611d76c23e137e1676abe.diff
LOG: [mlir][tosa] Added tosa.reverse folder
Fold cases where the input of a tosa.reverse is a splat or the reversed
dimension has length 1.
Reviewed By: jpienaar
Differential Revision: https://reviews.llvm.org/D133144
Added:
Modified:
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
mlir/test/Dialect/Tosa/constant-op-fold.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index c50fee2c5185b..8518d6bf54842 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1464,6 +1464,8 @@ def Tosa_ReverseOp: Tosa_Op<"reverse", [
let results = (outs
Tosa_Tensor1Dto4D:$output
);
+
+ let hasFolder = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index abe33f8366d52..b346b11cbb1ed 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -846,6 +846,21 @@ OpFoldResult PadOp::fold(ArrayRef<Attribute> operands) {
return {};
}
+// Folds tosa.reverse when reversing cannot change the tensor contents:
+//  - the input is a splat constant (every element is identical), or
+//  - the reversed dimension has a static length of 1.
+// A fold result must have exactly the op's result type, so nothing is folded
+// when the input and result types differ (e.g. static vs. dynamic dims).
+OpFoldResult ReverseOp::fold(ArrayRef<Attribute> operands) {
+  auto operand = getInput();
+  if (operand.getType() != getType())
+    return {};
+
+  // Reversing a splat yields the same splat.
+  if (auto operandAttr = operands[0].dyn_cast_or_null<SplatElementsAttr>())
+    return operandAttr;
+
+  // If the dim-length is 1, tosa.reverse is a no-op. getDimSize requires an
+  // in-bounds index and nothing here guarantees `axis` lies in [0, rank), so
+  // range-check it first. The attribute getter returns an unsigned value;
+  // converting to int64_t lets a negative axis be rejected explicitly.
+  auto operandTy = operand.getType().cast<ShapedType>();
+  int64_t axis = getAxis();
+  if (operandTy.hasRank() && axis >= 0 && axis < operandTy.getRank() &&
+      operandTy.getDimSize(axis) == 1)
+    return operand;
+
+  return {};
+}
+
OpFoldResult SliceOp::fold(ArrayRef<Attribute> operands) {
auto inputTy = getInput().getType().dyn_cast<RankedTensorType>();
auto outputTy = getType().dyn_cast<RankedTensorType>();
diff --git a/mlir/test/Dialect/Tosa/constant-op-fold.mlir b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
index 28be1ff67649a..08115787db58a 100644
--- a/mlir/test/Dialect/Tosa/constant-op-fold.mlir
+++ b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
@@ -540,3 +540,25 @@ func.func @cast_int_to_int_sign() -> tensor<i32> {
// CHECK: return %[[SPLAT]]
return %cast : tensor<i32>
}
+
+// -----
+
+// Reversing a splat constant cannot change its contents; the CHECK lines
+// expect the tosa.reverse to disappear and the constant to be returned.
+// CHECK-LABEL: @reverse_splat
+func.func @reverse_splat() -> tensor<10xi32> {
+ // CHECK: %[[SPLAT:.+]] = "tosa.const"() {value = dense<42> : tensor<10xi32>}
+ %splat = "tosa.const"() {value = dense<42> : tensor<10xi32>} : () -> tensor<10xi32>
+ %reverse = "tosa.reverse"(%splat) { axis = 0 : i64 } : (tensor<10xi32>) -> tensor<10xi32>
+ // CHECK: return %[[SPLAT]]
+ return %reverse : tensor<10xi32>
+}
+
+// -----
+
+// Only a reverse over a length-1 dimension is a no-op: the axis = 1 reverse
+// (dim size 1) is expected to be replaced by %arg0, while the axis = 0
+// reverse (dim size 10) must remain in the output.
+// CHECK-LABEL: @reverse_length_one
+func.func @reverse_length_one(%arg0 : tensor<10x1xi32>) -> (tensor<10x1xi32>, tensor<10x1xi32>) {
+ %nofold = "tosa.reverse"(%arg0) { axis = 0 : i64 } : (tensor<10x1xi32>) -> tensor<10x1xi32>
+ %fold = "tosa.reverse"(%arg0) { axis = 1 : i64 } : (tensor<10x1xi32>) -> tensor<10x1xi32>
+ // CHECK: %[[NOFOLD:.+]] = "tosa.reverse"(%arg0) {axis = 0 : i64}
+ // CHECK: return %[[NOFOLD]], %arg0
+ return %nofold, %fold : tensor<10x1xi32>, tensor<10x1xi32>
+}
More information about the Mlir-commits
mailing list