[Mlir-commits] [mlir] f19a469 - [mlir][tosa] Added folders for tosa.sub

Rob Suderman llvmlistbot at llvm.org
Mon Aug 29 16:23:29 PDT 2022


Author: Rob Suderman
Date: 2022-08-29T16:22:27-07:00
New Revision: f19a4695a7c55e249a08730fa7d83d6b0b94dbf6

URL: https://github.com/llvm/llvm-project/commit/f19a4695a7c55e249a08730fa7d83d6b0b94dbf6
DIFF: https://github.com/llvm/llvm-project/commit/f19a4695a7c55e249a08730fa7d83d6b0b94dbf6.diff

LOG: [mlir][tosa] Added folders for tosa.sub

Added folders for tosa.sub that handle bypassing the subtraction of zero
and folding the subtraction of two splat tensors.

Reviewed By: NatashaKnk

Differential Revision: https://reviews.llvm.org/D132618

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
    mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
    mlir/test/Dialect/Tosa/constant-op-fold.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index c256451cc190f..ff8c47c3b1a4e 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -785,6 +785,8 @@ def Tosa_SubOp : Tosa_Op<"sub", [
   let results = (outs
     Tosa_Tensor:$output
   );
+
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 200d81260bbf5..39c74974ab1f6 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -506,6 +506,36 @@ OpFoldResult AddOp::fold(ArrayRef<Attribute> operands) {
                                                             lhsTy);
 }
 
+// Folds `x - 0` to `x`; otherwise hands two constant operands to BinaryFolder.
+// A null result means the op was left unfolded.
+OpFoldResult SubOp::fold(ArrayRef<Attribute> operands) {
+  auto lhsTy = getInput1().getType().dyn_cast<RankedTensorType>();
+  auto rhsTy = getInput2().getType().dyn_cast<RankedTensorType>();
+  auto resultTy = getType().dyn_cast<RankedTensorType>();
+  if (!lhsTy || !rhsTy || !resultTy || lhsTy != rhsTy)
+    return {};
+
+  auto resultETy = resultTy.getElementType();
+  auto lhsAttr = operands[0].dyn_cast_or_null<DenseElementsAttr>();
+  auto rhsAttr = operands[1].dyn_cast_or_null<DenseElementsAttr>();
+
+  // Returning input1 is only legal when it already has the result type.
+  if (lhsTy == resultTy && rhsAttr && rhsAttr.isSplat()) {
+    if (resultETy.isa<FloatType>() &&
+        rhsAttr.getSplatValue<APFloat>().isZero())
+      return getInput1();
+    if (resultETy.isa<IntegerType>() &&
+        rhsAttr.getSplatValue<APInt>().isZero())
+      return getInput1();
+  }
+
+  // Constant folding needs both operands to be constant.
+  if (!lhsAttr || !rhsAttr)
+    return {};
+  return BinaryFolder<std::minus<APInt>, std::minus<APFloat>>(lhsAttr, rhsAttr,
+                                                              lhsTy);
+}
+
 namespace {
 template <typename Cmp>
 struct ComparisonFold {

diff  --git a/mlir/test/Dialect/Tosa/constant-op-fold.mlir b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
index 4c12dafcfe5d8..f023cc3571ee5 100644
--- a/mlir/test/Dialect/Tosa/constant-op-fold.mlir
+++ b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
@@ -164,6 +164,50 @@ func.func @fold_add_splat_f32() -> tensor<10xf32> {
 
 // -----
 
+// CHECK-LABEL: @fold_sub_zero_rhs_f32
+func.func @fold_sub_zero_rhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
+  %cst = "tosa.const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
+  %diff = "tosa.sub"(%arg0, %cst) : (tensor<f32>, tensor<f32>) -> tensor<f32>
+  // CHECK: return %arg0
+  return %diff : tensor<f32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_sub_zero_rhs_i32
+func.func @fold_sub_zero_rhs_i32(%arg0: tensor<i32>) -> tensor<i32> {
+  %cst = "tosa.const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+  %diff = "tosa.sub"(%arg0, %cst) : (tensor<i32>, tensor<i32>) -> tensor<i32>
+  // CHECK: return %arg0
+  return %diff : tensor<i32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_sub_splat_i32
+func.func @fold_sub_splat_i32() -> tensor<10xi32> {
+  %one = "tosa.const"() {value = dense<1> : tensor<10xi32>} : () -> tensor<10xi32>
+  %two = "tosa.const"() {value = dense<2> : tensor<10xi32>} : () -> tensor<10xi32>
+  %sub = "tosa.sub"(%one, %two) : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi32>
+  // CHECK: %[[NEG_ONE:.+]] = "tosa.const"() {value = dense<-1> : tensor<10xi32>}
+  // CHECK: return %[[NEG_ONE]]
+  return %sub : tensor<10xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_sub_splat_f32
+func.func @fold_sub_splat_f32() -> tensor<10xf32> {
+  %one = "tosa.const"() {value = dense<1.0> : tensor<10xf32>} : () -> tensor<10xf32>
+  %two = "tosa.const"() {value = dense<2.0> : tensor<10xf32>} : () -> tensor<10xf32>
+  %sub = "tosa.sub"(%one, %two) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
+  // CHECK: %[[NEG_ONE:.+]] = "tosa.const"() {value = dense<-1.000000e+00> : tensor<10xf32>}
+  // CHECK: return %[[NEG_ONE]]
+  return %sub : tensor<10xf32>
+}
+
+// -----
+
 // CHECK-LABEL: @fold_greater_splat_f32_true
 func.func @fold_greater_splat_f32_true() -> tensor<10xi1> {
   %one = "tosa.const"() {value = dense<4.0> : tensor<10xf32>} : () -> tensor<10xf32>


        


More information about the Mlir-commits mailing list