[Mlir-commits] [mlir] f5c8c9d - [mlir][tosa] Added folders for tosa.add

Rob Suderman llvmlistbot at llvm.org
Wed Aug 24 15:15:13 PDT 2022


Author: Rob Suderman
Date: 2022-08-24T15:13:02-07:00
New Revision: f5c8c9d51cb766c99d90c5ca87aab6e18c43c634

URL: https://github.com/llvm/llvm-project/commit/f5c8c9d51cb766c99d90c5ca87aab6e18c43c634
DIFF: https://github.com/llvm/llvm-project/commit/f5c8c9d51cb766c99d90c5ca87aab6e18c43c634.diff

LOG: [mlir][tosa] Added folders for tosa.add

Added folders for tosa.add that handle bypassing add-zero,
folding additions of two splat tensors, and folding additions
between two tensors with small values.

Reviewed By: jpienaar

Differential Revision: https://reviews.llvm.org/D132272

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
    mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
    mlir/test/Dialect/Tosa/constant-op-fold.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 9fcd955693f08..620ee39bd139a 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -420,6 +420,7 @@ def Tosa_AddOp : Tosa_Op<"add", [
   );
 
   let hasCanonicalizer = 1;
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index d1dcbec32f67e..571cd52eea89d 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -26,6 +26,8 @@
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/TypeSwitch.h"
 
+#include <functional>
+
 using namespace mlir;
 using namespace mlir::tosa;
 
@@ -437,6 +439,68 @@ void ClampOp::getCanonicalizationPatterns(RewritePatternSet &results,
 // Operator Folders.
 //===----------------------------------------------------------------------===//
 
+// Folds two splat constants into one splat of type `ty`, combining them with
+// IntFolder or FloatFolder; returns a null attribute when no fold applies.
+template <typename IntFolder, typename FloatFolder>
+DenseElementsAttr BinaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
+                               RankedTensorType ty) {
+  if (!rhs || !lhs || !rhs.isSplat() || !lhs.isSplat())
+    return {};
+
+  Type elementTy = ty.getElementType();
+  if (elementTy.isa<IntegerType>()) {
+    APInt folded =
+        IntFolder()(lhs.getSplatValue<APInt>(), rhs.getSplatValue<APInt>());
+    return DenseElementsAttr::get(ty, folded);
+  }
+  if (elementTy.isa<FloatType>()) {
+    APFloat folded = FloatFolder()(lhs.getSplatValue<APFloat>(),
+                                   rhs.getSplatValue<APFloat>());
+    return DenseElementsAttr::get(ty, folded);
+  }
+  return {};
+}
+
+// Folds tosa.add when an operand is a splat zero (returning the other
+// operand) or when both operands are splat constants (returning their sum).
+// Returns a null OpFoldResult when no fold applies.
+OpFoldResult AddOp::fold(ArrayRef<Attribute> operands) {
+  auto lhsTy = getInput1().getType().dyn_cast<RankedTensorType>();
+  auto rhsTy = getInput2().getType().dyn_cast<RankedTensorType>();
+  auto resultTy = getType().dyn_cast<RankedTensorType>();
+  if (!lhsTy || !rhsTy || !resultTy)
+    return {};
+  // A fold must yield a value of exactly the result type, so forwarding an
+  // operand (or reusing its type for the folded constant) is only valid when
+  // both operand types already match the result type.
+  if (lhsTy != resultTy || rhsTy != resultTy)
+    return {};
+
+  auto resultETy = resultTy.getElementType();
+  auto lhsAttr = operands[0].dyn_cast_or_null<DenseElementsAttr>();
+  auto rhsAttr = operands[1].dyn_cast_or_null<DenseElementsAttr>();
+
+  // True when `attr` is a splat constant whose value is zero.
+  auto isSplatZero = [&](DenseElementsAttr attr) {
+    if (!attr || !attr.isSplat())
+      return false;
+    if (resultETy.isa<FloatType>())
+      return attr.getSplatValue<APFloat>().isZero();
+    return resultETy.isa<IntegerType>() &&
+           attr.getSplatValue<APInt>().isZero();
+  };
+  if (isSplatZero(lhsAttr))
+    return getInput2();
+  if (isSplatZero(rhsAttr))
+    return getInput1();
+
+  if (!lhsAttr || !rhsAttr)
+    return {};
+
+  return BinaryFolder<std::plus<APInt>, std::plus<APFloat>>(lhsAttr, rhsAttr,
+                                                            resultTy);
+}
+
 OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
   if (getInput().getType() == getType())
     return getInput();

diff  --git a/mlir/test/Dialect/Tosa/constant-op-fold.mlir b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
index 09f8245e771a7..edf97144e0902 100644
--- a/mlir/test/Dialect/Tosa/constant-op-fold.mlir
+++ b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
@@ -97,3 +97,67 @@ func.func @transpose_nofold_quantized_types() -> tensor<1x1x16x1x!quant.uniform<
   %0 = "tosa.transpose"(%input, %perms) : (tensor<1x1x1x16xi8>, tensor<4xi32>) -> tensor<1x1x16x1x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,2.100000e+00,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01}>>
   return %0: tensor<1x1x16x1x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,2.100000e+00,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01}>>
 }
+
+// -----
+
+// CHECK-LABEL: @fold_add_zero_rhs_f32
+func.func @fold_add_zero_rhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
+  %zero = "tosa.const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
+  %add = "tosa.add"(%arg0, %zero) : (tensor<f32>, tensor<f32>) -> tensor<f32>
+  // CHECK: return %arg0
+  return %add : tensor<f32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_add_zero_lhs_f32
+func.func @fold_add_zero_lhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
+  %zero = "tosa.const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
+  %add = "tosa.add"(%zero, %arg0) : (tensor<f32>, tensor<f32>) -> tensor<f32>
+  // CHECK: return %arg0
+  return %add : tensor<f32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_add_zero_rhs_i32
+func.func @fold_add_zero_rhs_i32(%arg0: tensor<i32>) -> tensor<i32> {
+  %zero = "tosa.const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+  %add = "tosa.add"(%arg0, %zero) : (tensor<i32>, tensor<i32>) -> tensor<i32>
+  // CHECK: return %arg0
+  return %add : tensor<i32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_add_zero_lhs_i32
+func.func @fold_add_zero_lhs_i32(%arg0: tensor<i32>) -> tensor<i32> {
+  %zero = "tosa.const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+  %add = "tosa.add"(%zero, %arg0) : (tensor<i32>, tensor<i32>) -> tensor<i32>
+  // CHECK: return %arg0
+  return %add : tensor<i32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_add_splat_i32
+func.func @fold_add_splat_i32() -> tensor<10xi32> {
+  %one = "tosa.const"() {value = dense<1> : tensor<10xi32>} : () -> tensor<10xi32>
+  %two = "tosa.const"() {value = dense<2> : tensor<10xi32>} : () -> tensor<10xi32>
+  %add = "tosa.add"(%one, %two) : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi32>
+  // CHECK: %[[THREE:.+]] = "tosa.const"() {value = dense<3> : tensor<10xi32>}
+  // CHECK: return %[[THREE]]
+  return %add : tensor<10xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_add_splat_f32
+func.func @fold_add_splat_f32() -> tensor<10xf32> {
+  %one = "tosa.const"() {value = dense<1.0> : tensor<10xf32>} : () -> tensor<10xf32>
+  %two = "tosa.const"() {value = dense<2.0> : tensor<10xf32>} : () -> tensor<10xf32>
+  %add = "tosa.add"(%one, %two) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
+  // CHECK: %[[THREE:.+]] = "tosa.const"() {value = dense<3.000000e+00> : tensor<10xf32>}
+  // CHECK: return %[[THREE]]
+  return %add : tensor<10xf32>
+}


        


More information about the Mlir-commits mailing list