[Mlir-commits] [mlir] fffd966 - [mlir][tosa] Added folders for tosa.greater

Rob Suderman llvmlistbot at llvm.org
Mon Aug 29 13:24:40 PDT 2022


Author: Rob Suderman
Date: 2022-08-29T13:20:01-07:00
New Revision: fffd966f0b0f6aff879e2fd60cfc75336beb226b

URL: https://github.com/llvm/llvm-project/commit/fffd966f0b0f6aff879e2fd60cfc75336beb226b
DIFF: https://github.com/llvm/llvm-project/commit/fffd966f0b0f6aff879e2fd60cfc75336beb226b.diff

LOG: [mlir][tosa] Added folders for tosa.greater

Added folders for tosa.greater that fold splat values.

Reviewed By: NatashaKnk

Differential Revision: https://reviews.llvm.org/D132707

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
    mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
    mlir/test/Dialect/Tosa/constant-op-fold.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 620ee39bd139a..c256451cc190f 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1161,6 +1161,8 @@ def Tosa_GreaterOp : Tosa_Op<"greater", [
   let results = (outs
     I1Tensor:$output
   );
+
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 99477093b75ac..200d81260bbf5 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -441,20 +441,25 @@ void ClampOp::getCanonicalizationPatterns(RewritePatternSet &results,
 
 template <typename IntFolder, typename FloatFolder>
 DenseElementsAttr BinaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
-                               RankedTensorType ty) {
+                               RankedTensorType returnTy) {
   if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) {
-    if (ty.getElementType().isa<IntegerType>()) {
+    auto lETy = lhs.getType().cast<ShapedType>().getElementType();
+    auto rETy = rhs.getType().cast<ShapedType>().getElementType();
+    if (lETy != rETy)
+      return {};
+
+    if (lETy.isa<IntegerType>()) {
       APInt l = lhs.getSplatValue<APInt>();
       APInt r = rhs.getSplatValue<APInt>();
-      APInt result = IntFolder()(l, r);
-      return DenseElementsAttr::get(ty, result);
+      auto result = IntFolder()(l, r);
+      return DenseElementsAttr::get(returnTy, result);
     }
 
-    if (ty.getElementType().isa<FloatType>()) {
+    if (lETy.isa<FloatType>()) {
       APFloat l = lhs.getSplatValue<APFloat>();
       APFloat r = rhs.getSplatValue<APFloat>();
-      APFloat result = FloatFolder()(l, r);
-      return DenseElementsAttr::get(ty, result);
+      auto result = FloatFolder()(l, r);
+      return DenseElementsAttr::get(returnTy, result);
     }
   }
 
@@ -501,6 +506,37 @@ OpFoldResult AddOp::fold(ArrayRef<Attribute> operands) {
                                                             lhsTy);
 }
 
+namespace {
+/// Wraps bool-returning comparator `Cmp` as a folder yielding a 1-bit APInt.
+template <typename Cmp>
+struct ComparisonFold {
+  APInt operator()(const APInt &l, const APInt &r) const {
+    return APInt(1, Cmp()(l, r));
+  }
+  APInt operator()(const APFloat &l, const APFloat &r) const {
+    return APInt(1, Cmp()(l, r));
+  }
+};
+/// Signed integer greater-than (APInt::sgt) yielding a 1-bit APInt.
+struct APIntFoldGreater {
+  APInt operator()(const APInt &l, const APInt &r) const {
+    return APInt(1, l.sgt(r));
+  }
+};
+} // namespace
+
+OpFoldResult GreaterOp::fold(ArrayRef<Attribute> operands) {
+  // Fold greater(splat, splat) to an i1 splat constant (see BinaryFolder).
+  auto resultTy = getType().dyn_cast<RankedTensorType>();
+  auto lhsAttr = operands[0].dyn_cast_or_null<DenseElementsAttr>();
+  auto rhsAttr = operands[1].dyn_cast_or_null<DenseElementsAttr>();
+  // DenseElementsAttr::get needs a static shaped type: skip unranked/dynamic.
+  if (!resultTy || !resultTy.hasStaticShape() || !lhsAttr || !rhsAttr)
+    return {};
+  return BinaryFolder<APIntFoldGreater, ComparisonFold<std::greater<APFloat>>>(
+      lhsAttr, rhsAttr, resultTy);
+}
+
 OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
   if (getInput().getType() == getType())
     return getInput();

diff  --git a/mlir/test/Dialect/Tosa/constant-op-fold.mlir b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
index f392e4297be99..4c12dafcfe5d8 100644
--- a/mlir/test/Dialect/Tosa/constant-op-fold.mlir
+++ b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
@@ -164,6 +164,54 @@ func.func @fold_add_splat_f32() -> tensor<10xf32> {
 
 // -----
 
+// CHECK-LABEL: @fold_greater_splat_f32_true
+func.func @fold_greater_splat_f32_true() -> tensor<10xi1> {
+  %lhs = "tosa.const"() {value = dense<4.0> : tensor<10xf32>} : () -> tensor<10xf32>
+  %rhs = "tosa.const"() {value = dense<2.0> : tensor<10xf32>} : () -> tensor<10xf32>
+  %result = "tosa.greater"(%lhs, %rhs) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi1> // 4.0 > 2.0 is true
+  // CHECK: %[[BOOL:.+]] = "tosa.const"() {value = dense<true> : tensor<10xi1>}
+  // CHECK: return %[[BOOL]]
+  return %result : tensor<10xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_greater_splat_f32_false
+func.func @fold_greater_splat_f32_false() -> tensor<10xi1> {
+  %lhs = "tosa.const"() {value = dense<1.0> : tensor<10xf32>} : () -> tensor<10xf32>
+  %rhs = "tosa.const"() {value = dense<2.0> : tensor<10xf32>} : () -> tensor<10xf32>
+  %result = "tosa.greater"(%lhs, %rhs) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi1> // 1.0 > 2.0 is false
+  // CHECK: %[[BOOL:.+]] = "tosa.const"() {value = dense<false> : tensor<10xi1>}
+  // CHECK: return %[[BOOL]]
+  return %result : tensor<10xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_greater_splat_i32_false
+func.func @fold_greater_splat_i32_false() -> tensor<10xi1> {
+  %lhs = "tosa.const"() {value = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32>
+  %rhs = "tosa.const"() {value = dense<8> : tensor<10xi32>} : () -> tensor<10xi32>
+  %result = "tosa.greater"(%lhs, %rhs) : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi1> // signed: -10 > 8 is false (an unsigned compare would say true)
+  // CHECK: %[[BOOL:.+]] = "tosa.const"() {value = dense<false> : tensor<10xi1>}
+  // CHECK: return %[[BOOL]]
+  return %result : tensor<10xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_greater_splat_i32_true
+func.func @fold_greater_splat_i32_true() -> tensor<10xi1> {
+  %lhs = "tosa.const"() {value = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32>
+  %rhs = "tosa.const"() {value = dense<-12> : tensor<10xi32>} : () -> tensor<10xi32>
+  %result = "tosa.greater"(%lhs, %rhs) : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi1> // signed: -10 > -12 is true
+  // CHECK: %[[BOOL:.+]] = "tosa.const"() {value = dense<true> : tensor<10xi1>}
+  // CHECK: return %[[BOOL]]
+  return %result : tensor<10xi1>
+}
+
+// -----
+
 // CHECK-LABEL: @slice_splat
 func.func @slice_splat() -> tensor<1x1x1xi32> {
   // CHECK: %[[SLICE:.+]] = "tosa.const"() {value = dense<42> : tensor<1x1x1xi32>}


        


More information about the Mlir-commits mailing list