[Mlir-commits] [mlir] 5a23172 - [mlir][tosa] Add remaining tosa comparison folders

Rob Suderman llvmlistbot at llvm.org
Thu Sep 1 14:50:12 PDT 2022


Author: Rob Suderman
Date: 2022-09-01T14:48:46-07:00
New Revision: 5a231720bc0619aa8744d47470fee08afc643b4d

URL: https://github.com/llvm/llvm-project/commit/5a231720bc0619aa8744d47470fee08afc643b4d
DIFF: https://github.com/llvm/llvm-project/commit/5a231720bc0619aa8744d47470fee08afc643b4d.diff

LOG: [mlir][tosa] Add remaining tosa comparison folders

Added splat constant folders for the remaining comparison operations,
and a fold for tosa.equal when both operands are the same integer value.

Reviewed By: NatashaKnk

Differential Revision: https://reviews.llvm.org/D133138
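
For illustration, a minimal sketch of the rewrite these folders enable
(the tensor shape and constant values below are arbitrary examples, not
taken from the patch):

  // Before canonicalization: a splat integer comparison.
  %a = "tosa.const"() {value = dense<3> : tensor<4xi32>} : () -> tensor<4xi32>
  %b = "tosa.const"() {value = dense<3> : tensor<4xi32>} : () -> tensor<4xi32>
  %r = "tosa.greater_equal"(%a, %b) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>

  // After folding, the comparison collapses to a boolean splat constant.
  %r = "tosa.const"() {value = dense<true> : tensor<4xi1>} : () -> tensor<4xi1>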

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
    mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
    mlir/test/Dialect/Tosa/constant-op-fold.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index c19c9e4346d0..c50fee2c5185 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1143,6 +1143,8 @@ def Tosa_EqualOp : Tosa_Op<"equal", [InferTensorType, ResultsBroadcastableShape,
     /// InferTypeOpInterface.
     static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
   }];
+
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -1191,6 +1193,8 @@ def Tosa_GreaterEqualOp : Tosa_Op<"greater_equal", [
   let results = (outs
     I1Tensor:$output
   );
+
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//

diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 385e4068011f..abe33f8366d5 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -675,6 +675,11 @@ struct APIntFoldGreater {
   APIntFoldGreater() {}
   APInt operator()(APInt l, APInt r) { return APInt(1, l.sgt(r)); }
 };
+
+struct APIntFoldGreaterEqual {
+  APIntFoldGreaterEqual() {}
+  APInt operator()(APInt l, APInt r) { return APInt(1, l.sge(r)); }
+};
 } // namespace
 
 OpFoldResult GreaterOp::fold(ArrayRef<Attribute> operands) {
@@ -689,6 +694,42 @@ OpFoldResult GreaterOp::fold(ArrayRef<Attribute> operands) {
       lhsAttr, rhsAttr, resultTy);
 }
 
+OpFoldResult GreaterEqualOp::fold(ArrayRef<Attribute> operands) {
+  auto resultTy = getType().dyn_cast<RankedTensorType>();
+  auto lhsAttr = operands[0].dyn_cast_or_null<DenseElementsAttr>();
+  auto rhsAttr = operands[1].dyn_cast_or_null<DenseElementsAttr>();
+
+  if (!lhsAttr || !rhsAttr)
+    return {};
+
+  return BinaryFolder<APIntFoldGreaterEqual,
+                      ComparisonFold<std::greater_equal<APFloat>>>(
+      lhsAttr, rhsAttr, resultTy);
+}
+
+OpFoldResult EqualOp::fold(ArrayRef<Attribute> operands) {
+  auto resultTy = getType().dyn_cast<RankedTensorType>();
+  auto lhsAttr = operands[0].dyn_cast_or_null<DenseElementsAttr>();
+  auto rhsAttr = operands[1].dyn_cast_or_null<DenseElementsAttr>();
+  Value lhs = getInput1();
+  Value rhs = getInput2();
+  auto lhsTy = lhs.getType().cast<ShapedType>();
+
+  // Comparing an integer value to itself is always true. We cannot do the
+  // same for floats, since NaN compares unequal to itself.
+  if (lhsTy.getElementType().isa<IntegerType>() && resultTy.hasStaticShape() &&
+      lhs == rhs) {
+    return DenseElementsAttr::get(resultTy, true);
+  }
+
+  if (!lhsAttr || !rhsAttr)
+    return {};
+
+  return BinaryFolder<ComparisonFold<std::equal_to<APInt>>,
+                      ComparisonFold<std::equal_to<APFloat>>>(lhsAttr, rhsAttr,
+                                                              resultTy);
+}
+
 OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
   if (getInput().getType() == getType())
     return getInput();

diff --git a/mlir/test/Dialect/Tosa/constant-op-fold.mlir b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
index 43f2196c834d..28be1ff67649 100644
--- a/mlir/test/Dialect/Tosa/constant-op-fold.mlir
+++ b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
@@ -350,50 +350,108 @@ func.func @fold_sub_splat_f32() -> tensor<10xf32> {
 
 // -----
 
-// CHECK-LABEL: @fold_greater_splat_f32_true
-func.func @fold_greater_splat_f32_true() -> tensor<10xi1> {
-  %one = "tosa.const"() {value = dense<4.0> : tensor<10xf32>} : () -> tensor<10xf32>
-  %two = "tosa.const"() {value = dense<2.0> : tensor<10xf32>} : () -> tensor<10xf32>
-  %add = "tosa.greater"(%one, %two) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi1>
-  // CHECK: %[[BOOL:.+]] = "tosa.const"() {value = dense<true> : tensor<10xi1>}
-  // CHECK: return %[[BOOL]]
-  return %add : tensor<10xi1>
-}
-
-// -----
-
-// CHECK-LABEL: @fold_greater_splat_f32_false
-func.func @fold_greater_splat_f32_false() -> tensor<10xi1> {
-  %one = "tosa.const"() {value = dense<1.0> : tensor<10xf32>} : () -> tensor<10xf32>
-  %two = "tosa.const"() {value = dense<2.0> : tensor<10xf32>} : () -> tensor<10xf32>
-  %add = "tosa.greater"(%one, %two) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi1>
-  // CHECK: %[[BOOL:.+]] = "tosa.const"() {value = dense<false> : tensor<10xi1>}
-  // CHECK: return %[[BOOL]]
-  return %add : tensor<10xi1>
-}
-
-// -----
-
-// CHECK-LABEL: @fold_greater_splat_i32_false
-func.func @fold_greater_splat_i32_false() -> tensor<10xi1> {
-  %one = "tosa.const"() {value = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32>
-  %two = "tosa.const"() {value = dense<8> : tensor<10xi32>} : () -> tensor<10xi32>
-  %add = "tosa.greater"(%one, %two) : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi1>
-  // CHECK: %[[BOOL:.+]] = "tosa.const"() {value = dense<false> : tensor<10xi1>}
-  // CHECK: return %[[BOOL]]
-  return %add : tensor<10xi1>
-}
-
-// -----
-
-// CHECK-LABEL: @fold_greater_splat_i32_true
-func.func @fold_greater_splat_i32_true() -> tensor<10xi1> {
-  %one = "tosa.const"() {value = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32>
-  %two = "tosa.const"() {value = dense<-12> : tensor<10xi32>} : () -> tensor<10xi32>
-  %add = "tosa.greater"(%one, %two) : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi1>
-  // CHECK: %[[BOOL:.+]] = "tosa.const"() {value = dense<true> : tensor<10xi1>}
-  // CHECK: return %[[BOOL]]
-  return %add : tensor<10xi1>
+// CHECK-LABEL: @fold_greater_splat_f32
+func.func @fold_greater_splat_f32() -> (tensor<10xi1>, tensor<10xi1>) {
+  %0 = "tosa.const"() {value = dense<4.0> : tensor<10xf32>} : () -> tensor<10xf32>
+  %1 = "tosa.const"() {value = dense<2.0> : tensor<10xf32>} : () -> tensor<10xf32>
+  %2 = "tosa.const"() {value = dense<1.0> : tensor<10xf32>} : () -> tensor<10xf32>
+  %3 = "tosa.const"() {value = dense<2.0> : tensor<10xf32>} : () -> tensor<10xf32>
+  %true = "tosa.greater"(%0, %1) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi1>
+  %false = "tosa.greater"(%2, %3) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi1>
+  // CHECK-DAG: %[[TRUE:.+]] = "tosa.const"() {value = dense<true> : tensor<10xi1>}
+  // CHECK-DAG: %[[FALSE:.+]] = "tosa.const"() {value = dense<false> : tensor<10xi1>}
+  // CHECK: return %[[TRUE]], %[[FALSE]]
+  return %true, %false : tensor<10xi1>, tensor<10xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_greater_splat_i32
+func.func @fold_greater_splat_i32() -> (tensor<10xi1>, tensor<10xi1>) {
+  %0 = "tosa.const"() {value = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32>
+  %1 = "tosa.const"() {value = dense<8> : tensor<10xi32>} : () -> tensor<10xi32>
+  %2 = "tosa.const"() {value = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32>
+  %3 = "tosa.const"() {value = dense<-12> : tensor<10xi32>} : () -> tensor<10xi32>
+  %false = "tosa.greater"(%0, %1) : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi1>
+  %true = "tosa.greater"(%2, %3) : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi1>
+  // CHECK-DAG: %[[FALSE:.+]] = "tosa.const"() {value = dense<false> : tensor<10xi1>}
+  // CHECK-DAG: %[[TRUE:.+]] = "tosa.const"() {value = dense<true> : tensor<10xi1>}
+  // CHECK: return %[[FALSE]], %[[TRUE]]
+  return %false, %true : tensor<10xi1>, tensor<10xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_greater_eq_splat_f32
+func.func @fold_greater_eq_splat_f32() -> (tensor<10xi1>, tensor<10xi1>) {
+  %0 = "tosa.const"() {value = dense<4.0> : tensor<10xf32>} : () -> tensor<10xf32>
+  %1 = "tosa.const"() {value = dense<4.0> : tensor<10xf32>} : () -> tensor<10xf32>
+  %2 = "tosa.const"() {value = dense<1.0> : tensor<10xf32>} : () -> tensor<10xf32>
+  %3 = "tosa.const"() {value = dense<2.0> : tensor<10xf32>} : () -> tensor<10xf32>
+  %true = "tosa.greater_equal"(%0, %1) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi1>
+  %false = "tosa.greater_equal"(%2, %3) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi1>
+  // CHECK-DAG: %[[TRUE:.+]] = "tosa.const"() {value = dense<true> : tensor<10xi1>}
+  // CHECK-DAG: %[[FALSE:.+]] = "tosa.const"() {value = dense<false> : tensor<10xi1>}
+  // CHECK: return %[[TRUE]], %[[FALSE]]
+  return %true, %false : tensor<10xi1>, tensor<10xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_greater_eq_splat_i32
+func.func @fold_greater_eq_splat_i32() -> (tensor<10xi1>, tensor<10xi1>) {
+  %0 = "tosa.const"() {value = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32>
+  %1 = "tosa.const"() {value = dense<8> : tensor<10xi32>} : () -> tensor<10xi32>
+  %2 = "tosa.const"() {value = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32>
+  %3 = "tosa.const"() {value = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32>
+  %true = "tosa.greater_equal"(%2, %3) : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi1>
+  %false = "tosa.greater_equal"(%0, %1) : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi1>
+  // CHECK-DAG: %[[TRUE:.+]] = "tosa.const"() {value = dense<true> : tensor<10xi1>}
+  // CHECK-DAG: %[[FALSE:.+]] = "tosa.const"() {value = dense<false> : tensor<10xi1>}
+  // CHECK: return %[[TRUE]], %[[FALSE]]
+  return %true, %false : tensor<10xi1>, tensor<10xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_eq_splat_f32
+func.func @fold_eq_splat_f32() -> (tensor<10xi1>, tensor<10xi1>) {
+  %0 = "tosa.const"() {value = dense<4.0> : tensor<10xf32>} : () -> tensor<10xf32>
+  %1 = "tosa.const"() {value = dense<4.0> : tensor<10xf32>} : () -> tensor<10xf32>
+  %2 = "tosa.const"() {value = dense<1.0> : tensor<10xf32>} : () -> tensor<10xf32>
+  %3 = "tosa.const"() {value = dense<2.0> : tensor<10xf32>} : () -> tensor<10xf32>
+  %true = "tosa.equal"(%0, %1) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi1>
+  %false = "tosa.equal"(%2, %3) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi1>
+  // CHECK-DAG: %[[TRUE:.+]] = "tosa.const"() {value = dense<true> : tensor<10xi1>}
+  // CHECK-DAG: %[[FALSE:.+]] = "tosa.const"() {value = dense<false> : tensor<10xi1>}
+  // CHECK: return %[[TRUE]], %[[FALSE]]
+  return %true, %false : tensor<10xi1>, tensor<10xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_eq_splat_i32
+func.func @fold_eq_splat_i32() -> (tensor<10xi1>, tensor<10xi1>) {
+  %0 = "tosa.const"() {value = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32>
+  %1 = "tosa.const"() {value = dense<8> : tensor<10xi32>} : () -> tensor<10xi32>
+  %2 = "tosa.const"() {value = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32>
+  %3 = "tosa.const"() {value = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32>
+  %true = "tosa.equal"(%2, %3) : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi1>
+  %false = "tosa.equal"(%0, %1) : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi1>
+  // CHECK-DAG: %[[TRUE:.+]] = "tosa.const"() {value = dense<true> : tensor<10xi1>}
+  // CHECK-DAG: %[[FALSE:.+]] = "tosa.const"() {value = dense<false> : tensor<10xi1>}
+  // CHECK: return %[[TRUE]], %[[FALSE]]
+  return %true, %false : tensor<10xi1>, tensor<10xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_eq_i32
+func.func @fold_eq_i32(%arg0 : tensor<10xi32>) -> (tensor<10xi1>) {
+  // CHECK: %[[TRUE:.+]] = "tosa.const"() {value = dense<true> : tensor<10xi1>}
+  %0 = "tosa.equal"(%arg0, %arg0) : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi1>
+  // CHECK: return %[[TRUE]]
+  return %0 : tensor<10xi1>
 }
 
 // -----


        

