[Mlir-commits] [mlir] [mlir][tosa] Enhance folder for Tosa binary operators (PR #128059)

Tai Ly llvmlistbot at llvm.org
Thu Feb 20 12:14:18 PST 2025


https://github.com/Tai78641 created https://github.com/llvm/llvm-project/pull/128059

This enhances the folder for TOSA binary operators to support non-splat constant attributes for the following ops:
 - mul
 - add
 - sub
 - greater
 - greater_equal
 - equal

>From c4bbd39bd588bc2038d435aff4d85c89712ef3cf Mon Sep 17 00:00:00 2001
From: Tai Ly <tai.ly at arm.com>
Date: Tue, 4 Jun 2024 18:01:33 +0000
Subject: [PATCH] [mlir][tosa] Enhance folder for Tosa binary operators

This enhances the folder for TOSA binary operators to support non-splat constant attributes for
the following ops:
 - mul
 - add
 - sub
 - greater
 - greater_equal
 - equal

Signed-off-by: Tai Ly <tai.ly at arm.com>
Change-Id: I3198a808988a71b5894d8f7c410b407340564c38
---
 .../Dialect/Tosa/IR/TosaCanonicalizations.cpp | 180 ++++++++++++---
 mlir/test/Dialect/Tosa/constant-op-fold.mlir  |   7 +-
 mlir/test/Dialect/Tosa/constant_folding.mlir  | 218 +++++++++++++++++-
 3 files changed, 365 insertions(+), 40 deletions(-)

diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 9bfc2aae1d6a5..2c6c6e2ed284c 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -563,15 +563,21 @@ void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
 // Operator Folders.
 //===----------------------------------------------------------------------===//
 
-template <typename IntFolder, typename FloatFolder>
+template <typename IntFolder, typename FloatFolder, typename FloatResultAPType>
 DenseElementsAttr binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
                                RankedTensorType returnTy) {
-  if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) {
-    auto lETy = llvm::cast<ShapedType>(lhs.getType()).getElementType();
-    auto rETy = llvm::cast<ShapedType>(rhs.getType()).getElementType();
-    if (lETy != rETy)
-      return {};
+  if (!rhs || !lhs)
+    return {};
+
+  auto lETy = llvm::cast<ShapedType>(lhs.getType()).getElementType();
+  auto rETy = llvm::cast<ShapedType>(rhs.getType()).getElementType();
+  if (lETy != rETy)
+    return {};
+
+  if (!lETy.isIntOrFloat())
+    return {};
 
+  if (rhs.isSplat() && lhs.isSplat()) {
     if (llvm::isa<IntegerType>(lETy)) {
       APInt l = lhs.getSplatValue<APInt>();
       APInt r = rhs.getSplatValue<APInt>();
@@ -587,9 +593,54 @@ DenseElementsAttr binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
     }
   }
 
+  if (llvm::isa<IntegerType>(lETy)) {
+    auto lvalues = lhs.getValues<APInt>();
+    auto rvalues = rhs.getValues<APInt>();
+    if (lvalues.size() != rvalues.size()) {
+      return {};
+    }
+    SmallVector<APInt> results;
+    for (const auto &[l, r] : llvm::zip(lvalues, rvalues)) {
+      auto result = IntFolder()(l, r);
+      results.push_back(result);
+    }
+    return DenseElementsAttr::get(returnTy, results);
+  }
+
+  if (llvm::isa<FloatType>(lETy)) {
+    auto lvalues = lhs.getValues<APFloat>();
+    auto rvalues = rhs.getValues<APFloat>();
+    if (lvalues.size() != rvalues.size()) {
+      return {};
+    }
+    // FloatFolder() may return either APFloat or APInt (comparison functions)
+    SmallVector<FloatResultAPType> results;
+    for (const auto &[l, r] : llvm::zip(lvalues, rvalues)) {
+      auto result = FloatFolder()(l, r);
+      results.push_back(result);
+    }
+    return DenseElementsAttr::get(returnTy, results);
+  }
+
   return {};
 }
 
+template <typename IntFolder, typename FloatFolder>
+DenseElementsAttr comparisonBinaryFolder(DenseElementsAttr lhs,
+                                         DenseElementsAttr rhs,
+                                         RankedTensorType returnTy) {
+  // comparison FloatFolder() functions return APInt values
+  return binaryFolder<IntFolder, FloatFolder, APInt>(lhs, rhs, returnTy);
+}
+
+template <typename IntFolder, typename FloatFolder>
+DenseElementsAttr arithmeticBinaryFolder(DenseElementsAttr lhs,
+                                         DenseElementsAttr rhs,
+                                         RankedTensorType returnTy) {
+  // arithmetic FloatFolder() functions return APFloat values
+  return binaryFolder<IntFolder, FloatFolder, APFloat>(lhs, rhs, returnTy);
+}
+
 static bool isSplatZero(Type elemType, DenseElementsAttr val) {
   if (llvm::isa<FloatType>(elemType))
     return val && val.isSplat() && val.getSplatValue<APFloat>().isZero();
@@ -636,8 +687,8 @@ OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
   if (!lhsAttr || !rhsAttr)
     return {};
 
-  return binaryFolder<std::plus<APInt>, std::plus<APFloat>>(lhsAttr, rhsAttr,
-                                                            resultTy);
+  return arithmeticBinaryFolder<std::plus<APInt>, std::plus<APFloat>>(
+      lhsAttr, rhsAttr, resultTy);
 }
 
 OpFoldResult ArgMaxOp::fold(FoldAdaptor adaptor) {
@@ -693,32 +744,96 @@ OpFoldResult IntDivOp::fold(FoldAdaptor adaptor) {
 }
 
 namespace {
+
+// calculate lhs * rhs >> shift according to TOSA Spec
+// return nullopt if result is not in range of int32_t when shift > 0
+std::optional<APInt> mulInt(APInt lhs, APInt rhs, int32_t shift,
+                            unsigned bitwidth) {
+  APInt result = lhs.sext(64) * rhs.sext(64);
+
+  if (shift > 0) {
+    auto round = APInt(64, 1) << (shift - 1);
+    result += round;
+    result.ashrInPlace(shift);
+    // REQUIRE(product >= minimum_s<i32_t>() && product <= maximum_s<i32_t>())
+    if (!(result.getSExtValue() >= INT32_MIN &&
+          result.getSExtValue() <= INT32_MAX)) {
+      // REQUIRE failed
+      return std::nullopt;
+    }
+  }
+
+  return result.trunc(bitwidth);
+}
+
 DenseElementsAttr mulBinaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
                                   RankedTensorType ty, int32_t shift) {
-  if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) {
-    if (llvm::isa<IntegerType>(ty.getElementType())) {
-      APInt l = lhs.getSplatValue<APInt>();
-      APInt r = rhs.getSplatValue<APInt>();
+  if (!lhs || !rhs)
+    return {};
+
+  // REQUIRE(0 <= shift && shift <= 63);
+  if (!(0 <= shift && shift <= 63))
+    return {};
 
-      if (shift == 0) {
-        return DenseElementsAttr::get(ty, l * r);
+  auto elementType = ty.getElementType();
+  if (!elementType.isIntOrFloat())
+    return {};
+
+  unsigned bitwidth = elementType.getIntOrFloatBitWidth();
+  // REQUIRE(in_t == int32_t || shift == 0);
+  if (!((llvm::isa<IntegerType>(elementType) && bitwidth == 32) || shift == 0))
+    return {};
+
+  if (rhs.isSplat() && lhs.isSplat()) {
+    if (llvm::isa<IntegerType>(elementType)) {
+      auto l = lhs.getSplatValue<APInt>();
+      auto r = rhs.getSplatValue<APInt>();
+
+      if (auto result = mulInt(l, r, shift, bitwidth)) {
+        return DenseElementsAttr::get(ty, result.value());
       }
+      // mulInt failed
+      return {};
+    }
 
-      auto bitwidth = ty.getElementType().getIntOrFloatBitWidth();
-      l = l.sext(bitwidth * 2);
-      r = r.sext(bitwidth * 2);
+    if (llvm::isa<FloatType>(elementType)) {
+      auto l = lhs.getSplatValue<APFloat>();
+      auto r = rhs.getSplatValue<APFloat>();
       auto result = l * r;
-      result.lshrInPlace(shift);
-      result = result.trunc(bitwidth);
       return DenseElementsAttr::get(ty, result);
     }
+  }
+
+  if (llvm::isa<IntegerType>(elementType)) {
+    auto lvalues = lhs.getValues<APInt>();
+    auto rvalues = rhs.getValues<APInt>();
+    if (lvalues.size() != rvalues.size()) {
+      return {};
+    }
+    SmallVector<APInt> results;
+    for (const auto &[l, r] : llvm::zip(lvalues, rvalues)) {
+      if (auto result = mulInt(l, r, shift, bitwidth)) {
+        results.push_back(result.value());
+        continue;
+      }
+      // mulInt failed
+      return {};
+    }
+    return DenseElementsAttr::get(ty, results);
+  }
 
-    if (llvm::isa<FloatType>(ty.getElementType())) {
-      APFloat l = lhs.getSplatValue<APFloat>();
-      APFloat r = rhs.getSplatValue<APFloat>();
-      APFloat result = l * r;
-      return DenseElementsAttr::get(ty, result);
+  if (llvm::isa<FloatType>(elementType)) {
+    auto lvalues = lhs.getValues<APFloat>();
+    auto rvalues = rhs.getValues<APFloat>();
+    if (lvalues.size() != rvalues.size()) {
+      return {};
     }
+    SmallVector<APFloat> results;
+    for (const auto &[l, r] : llvm::zip(lvalues, rvalues)) {
+      auto result = l * r;
+      results.push_back(result);
+    }
+    return DenseElementsAttr::get(ty, results);
   }
 
   return {};
@@ -793,8 +908,8 @@ OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
   if (!lhsAttr || !rhsAttr)
     return {};
 
-  return binaryFolder<std::minus<APInt>, std::minus<APFloat>>(lhsAttr, rhsAttr,
-                                                              resultTy);
+  return arithmeticBinaryFolder<std::minus<APInt>, std::minus<APFloat>>(
+      lhsAttr, rhsAttr, resultTy);
 }
 
 namespace {
@@ -835,7 +950,8 @@ OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
   if (!lhsAttr || !rhsAttr)
     return {};
 
-  return binaryFolder<APIntFoldGreater, ComparisonFold<std::greater<APFloat>>>(
+  return comparisonBinaryFolder<APIntFoldGreater,
+                                ComparisonFold<std::greater<APFloat>>>(
       lhsAttr, rhsAttr, resultTy);
 }
 
@@ -849,8 +965,8 @@ OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
   if (!lhsAttr || !rhsAttr)
     return {};
 
-  return binaryFolder<APIntFoldGreaterEqual,
-                      ComparisonFold<std::greater_equal<APFloat>>>(
+  return comparisonBinaryFolder<APIntFoldGreaterEqual,
+                                ComparisonFold<std::greater_equal<APFloat>>>(
       lhsAttr, rhsAttr, resultTy);
 }
 
@@ -874,9 +990,9 @@ OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
   if (!lhsAttr || !rhsAttr)
     return {};
 
-  return binaryFolder<ComparisonFold<std::equal_to<APInt>>,
-                      ComparisonFold<std::equal_to<APFloat>>>(lhsAttr, rhsAttr,
-                                                              resultTy);
+  return comparisonBinaryFolder<ComparisonFold<std::equal_to<APInt>>,
+                                ComparisonFold<std::equal_to<APFloat>>>(
+      lhsAttr, rhsAttr, resultTy);
 }
 
 OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
diff --git a/mlir/test/Dialect/Tosa/constant-op-fold.mlir b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
index e6fb741df9598..5aab368fa044d 100644
--- a/mlir/test/Dialect/Tosa/constant-op-fold.mlir
+++ b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
@@ -1092,11 +1092,8 @@ func.func @reduce_sum_constant() -> tensor<1x3xi32> {
 
 func.func @reduce_sum_constant() -> tensor<1x3xi32> {
   // CHECK-LABEL:     func.func @reduce_sum_constant() -> tensor<1x3xi32> {
-  // CHECK:           %[[VAL_0:.*]] = "tosa.const"() <{value = dense<{{\[\[}}1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>}> : () -> tensor<2x3xi32>
-  // CHECK:           %[[VAL_1:.*]] = "tosa.const"() <{value = dense<{{\[\[}}1, 2, 3], [4, 5, 7]]> : tensor<2x3xi32>}> : () -> tensor<2x3xi32>
-  // CHECK:           %[[VAL_2:.*]] = tosa.add %[[VAL_0]], %[[VAL_1]] : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
-  // CHECK:           %[[VAL_3:.*]] = tosa.reduce_sum %[[VAL_2]] {axis = 0 : i32} : (tensor<2x3xi32>) -> tensor<1x3xi32>
-  // CHECK:           return %[[VAL_3]] : tensor<1x3xi32>
+  // CHECK:           %[[K:.*]] = "tosa.const"() <{value = dense<{{\[\[}}10, 14, 19]]> : tensor<1x3xi32>}> : () -> tensor<1x3xi32>
+  // CHECK:           return %[[K]] : tensor<1x3xi32>
   %arg0 = "tosa.const"() <{value = dense<[[1,2,3], [4,5,6]]> : tensor<2x3xi32>}> : () -> tensor<2x3xi32>
   %arg1 = "tosa.const"() <{value = dense<[[1,2,3], [4,5,7]]> : tensor<2x3xi32>}> : () -> tensor<2x3xi32>
   %arg2 = tosa.add %arg0, %arg1 : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
diff --git a/mlir/test/Dialect/Tosa/constant_folding.mlir b/mlir/test/Dialect/Tosa/constant_folding.mlir
index 3ff3121348fca..fee1ce7793b12 100644
--- a/mlir/test/Dialect/Tosa/constant_folding.mlir
+++ b/mlir/test/Dialect/Tosa/constant_folding.mlir
@@ -1,4 +1,6 @@
-// RUN: mlir-opt --test-constant-fold %s | FileCheck %s
+// RUN: mlir-opt --split-input-file --canonicalize --test-constant-fold %s | FileCheck %s
+
+// -----
 
 // CHECK-LABEL: func @test_const
 func.func @test_const(%arg0 : index) -> tensor<4xi32> {
@@ -7,6 +9,8 @@ func.func @test_const(%arg0 : index) -> tensor<4xi32> {
   return %0 : tensor<4xi32>
 }
 
+// -----
+
 // CHECK-LABEL: func @test_const_i64
 func.func @test_const_i64(%arg0 : index) -> tensor<4xi64> {
   // CHECK: tosa.const
@@ -14,10 +18,218 @@ func.func @test_const_i64(%arg0 : index) -> tensor<4xi64> {
   return %0 : tensor<4xi64>
 }
 
+// -----
+
 // CHECK-LABEL: func @try_fold_equal_with_unranked_tensor
-func.func @try_fold_equal_with_unranked_tensor(%arg0: tensor<4xi32>, %arg1: tensor<1xi32>) {
+func.func @try_fold_equal_with_unranked_tensor(%arg0: tensor<4xi32>, %arg1: tensor<1xi32>) -> tensor<*xi1> {
   // CHECK: tosa.equal
   // CHECK-NEXT: return
   %0 = tosa.equal %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi1>
-  return
+  return %0 : tensor<*xi1>
+}
+
+// -----
+
+// CHECK-LABEL: test_mul_i32
+func.func @test_mul_i32() -> tensor<4xi32> {
+  // CHECK: %[[VAL:.+]] = "tosa.const"() <{value = dense<[9, 36, 36, 81]> : tensor<4xi32>}>
+  // CHECK: return %[[VAL]]
+  %lhs = "tosa.const"() {value = dense<[1, 2, -2, -3]> : tensor<4xi32>} : () -> tensor<4xi32>
+  %rhs = "tosa.const"() {value = dense<3> : tensor<4xi32>} : () -> tensor<4xi32>
+  %shift = "tosa.const"() { value = dense<0> : tensor<1xi8> } : () -> tensor<1xi8>
+  %x = tosa.mul %lhs, %rhs, %shift : (tensor<4xi32>, tensor<4xi32>, tensor<1xi8>) -> tensor<4xi32>
+  %y = tosa.mul %rhs, %lhs, %shift : (tensor<4xi32>, tensor<4xi32>, tensor<1xi8>) -> tensor<4xi32>
+  %result = tosa.mul %x, %y, %shift : (tensor<4xi32>, tensor<4xi32>, tensor<1xi8>) -> tensor<4xi32>
+
+  return %result : tensor<4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: test_mul_i32_shift
+func.func @test_mul_i32_shift() -> tensor<4xi32> {
+  // CHECK: %[[VAL:.+]] = "tosa.const"() <{value = dense<[2550, 8100, 2, 2025]> : tensor<4xi32>}>
+  // CHECK: return %[[VAL]]
+  %lhs = "tosa.const"() {value = dense<[135, 240, -4, -120]> : tensor<4xi32>} : () -> tensor<4xi32>
+  %rhs = "tosa.const"() {value = dense<3> : tensor<4xi32>} : () -> tensor<4xi32>
+  %shift = "tosa.const"() { value = dense<2> : tensor<1xi8> } : () -> tensor<1xi8>
+  %x = tosa.mul %lhs, %rhs, %shift : (tensor<4xi32>, tensor<4xi32>, tensor<1xi8>) -> tensor<4xi32>
+  %y = tosa.mul %rhs, %lhs, %shift : (tensor<4xi32>, tensor<4xi32>, tensor<1xi8>) -> tensor<4xi32>
+  %result = tosa.mul %x, %y, %shift : (tensor<4xi32>, tensor<4xi32>, tensor<1xi8>) -> tensor<4xi32>
+  return %result : tensor<4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: test_mul_f32
+func.func @test_mul_f32() -> tensor<4xf32> {
+  // CHECK: %[[VAL:.+]] = "tosa.const"() <{value = dense<[2.304000e+01, 58.9824028, 1.6384002, 14.7456007]> : tensor<4xf32>}>
+  // CHECK: return %[[VAL]]
+  %lhs = "tosa.const"() {value = dense<[1.5, 2.4, -0.4, -1.2]> : tensor<4xf32>} : () -> tensor<4xf32>
+  %rhs = "tosa.const"() {value = dense<3.2> : tensor<4xf32>} : () -> tensor<4xf32>
+  %shift = "tosa.const"() { value = dense<0> : tensor<1xi8> } : () -> tensor<1xi8>
+  %x = tosa.mul %lhs, %rhs, %shift : (tensor<4xf32>, tensor<4xf32>, tensor<1xi8>) -> tensor<4xf32>
+  %y = tosa.mul %rhs, %lhs, %shift : (tensor<4xf32>, tensor<4xf32>, tensor<1xi8>) -> tensor<4xf32>
+  %result = tosa.mul %x, %y, %shift : (tensor<4xf32>, tensor<4xf32>, tensor<1xi8>) -> tensor<4xf32>
+  return %result : tensor<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_add_f32
+func.func @test_add_f32() -> tensor<4xf32> {
+  // CHECK: %[[VAL:.+]] = "tosa.const"() <{value = dense<[7.500000e+00, 9.300000e+00, 3.69999981, 2.100000e+00]> : tensor<4xf32>}>
+  // CHECK: return %[[VAL]]
+  %cst = "tosa.const"() {value = dense<[1.5, 2.4, -0.4, -1.2]> : tensor<4xf32>} : () -> tensor<4xf32>
+  %splat1 = "tosa.const"() {value = dense<3.2> : tensor<4xf32>} : () -> tensor<4xf32>
+  %splat2 = "tosa.const"() {value = dense<1.3> : tensor<4xf32>} : () -> tensor<4xf32>
+  %x = tosa.add %cst, %splat1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+  %y = tosa.add %splat2, %cst : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+  %result = tosa.add %x, %y : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+  return %result : tensor<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_add_i32
+func.func @test_add_i32() -> tensor<4xi32> {
+  // CHECK: %[[VAL:.+]] = "tosa.const"() <{value = dense<[75, 93, 37, 21]> : tensor<4xi32>}>
+  // CHECK: return %[[VAL]]
+  %cst = "tosa.const"() {value = dense<[15, 24, -4, -12]> : tensor<4xi32>} : () -> tensor<4xi32>
+  %splat1 = "tosa.const"() {value = dense<32> : tensor<4xi32>} : () -> tensor<4xi32>
+  %splat2 = "tosa.const"() {value = dense<13> : tensor<4xi32>} : () -> tensor<4xi32>
+  %x = tosa.add %cst, %splat1 : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
+  %y = tosa.add %splat2, %cst : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
+  %result = tosa.add %x, %y : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
+  return %result : tensor<4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: test_sub_f32
+func.func @test_sub_f32() -> tensor<4xf32> {
+  // CHECK: %[[VAL:.+]] = "tosa.const"() <{value = dense<[-1.500000e+00, 0.300000191, -5.300000e+00, -6.900000e+00]> : tensor<4xf32>}>
+  // CHECK: return %[[VAL]]
+  %cst = "tosa.const"() {value = dense<[1.5, 2.4, -0.4, -1.2]> : tensor<4xf32>} : () -> tensor<4xf32>
+  %splat1 = "tosa.const"() {value = dense<3.2> : tensor<4xf32>} : () -> tensor<4xf32>
+  %splat2 = "tosa.const"() {value = dense<1.3> : tensor<4xf32>} : () -> tensor<4xf32>
+  %x = tosa.sub %cst, %splat1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+  %y = tosa.sub %splat2, %cst : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+  %result = tosa.sub %x, %y : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+  return %result : tensor<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_sub_i32
+func.func @test_sub_i32() -> tensor<4xi32> {
+  // CHECK: %[[VAL:.+]] = "tosa.const"() <{value = dense<[-15, 3, -53, -69]> : tensor<4xi32>}>
+  // CHECK: return %[[VAL]]
+  %cst = "tosa.const"() {value = dense<[15, 24, -4, -12]> : tensor<4xi32>} : () -> tensor<4xi32>
+  %splat1 = "tosa.const"() {value = dense<32> : tensor<4xi32>} : () -> tensor<4xi32>
+  %splat2 = "tosa.const"() {value = dense<13> : tensor<4xi32>} : () -> tensor<4xi32>
+  %x = tosa.sub %cst, %splat1 : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
+  %y = tosa.sub %splat2, %cst : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
+  %result = tosa.sub %x, %y : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
+  return %result : tensor<4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: test_greater_f32
+func.func @test_greater_f32() -> (tensor<4xi1>, tensor<4xi1>, tensor<4xi1>) {
+  // CHECK: %[[VAL_0:.+]] = "tosa.const"() <{value = dense<[false, true, false, false]> : tensor<4xi1>}> : () -> tensor<4xi1>
+  // CHECK: %[[VAL_1:.+]] = "tosa.const"() <{value = dense<[false, false, true, true]> : tensor<4xi1>}> : () -> tensor<4xi1>
+  // CHECK: %[[VAL_2:.+]] = "tosa.const"() <{value = dense<[false, true, true, false]> : tensor<4xi1>}> : () -> tensor<4xi1>
+  // CHECK: return %[[VAL_0]], %[[VAL_1]], %[[VAL_2]]
+  %cst1 = "tosa.const"() {value = dense<[1.5, 2.4, -0.4, -1.2]> : tensor<4xf32>} : () -> tensor<4xf32>
+  %splat = "tosa.const"() {value = dense<1.5> : tensor<4xf32>} : () -> tensor<4xf32>
+  %cst2 = "tosa.const"() {value = dense<[1.7, 2.3, -0.5, -1.1]> : tensor<4xf32>} : () -> tensor<4xf32>
+  %x = tosa.greater %cst1, %splat : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
+  %y = tosa.greater %splat, %cst1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
+  %z = tosa.greater %cst1, %cst2 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
+  return %x, %y, %z : tensor<4xi1>, tensor<4xi1>, tensor<4xi1>
+}
+
+// -----
+
+// CHECK-LABEL: test_greater_i32
+func.func @test_greater_i32() -> (tensor<4xi1>, tensor<4xi1>, tensor<4xi1>) {
+  // CHECK: %[[VAL_0:.+]] = "tosa.const"() <{value = dense<[false, true, false, false]> : tensor<4xi1>}> : () -> tensor<4xi1>
+  // CHECK: %[[VAL_1:.+]] = "tosa.const"() <{value = dense<[false, false, true, true]> : tensor<4xi1>}> : () -> tensor<4xi1>
+  // CHECK: %[[VAL_2:.+]] = "tosa.const"() <{value = dense<[false, true, true, false]> : tensor<4xi1>}> : () -> tensor<4xi1>
+  // CHECK: return %[[VAL_0]], %[[VAL_1]], %[[VAL_2]]
+  %cst1 = "tosa.const"() {value = dense<[15, 24, -4, -12]> : tensor<4xi32>} : () -> tensor<4xi32>
+  %cst2 = "tosa.const"() {value = dense<[17, 23, -5, -11]> : tensor<4xi32>} : () -> tensor<4xi32>
+  %splat = "tosa.const"() {value = dense<15> : tensor<4xi32>} : () -> tensor<4xi32>
+  %x = tosa.greater %cst1, %splat : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
+  %y = tosa.greater %splat, %cst1 : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
+  %z = tosa.greater %cst1, %cst2 : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
+  return %x, %y, %z : tensor<4xi1>, tensor<4xi1>, tensor<4xi1>
+}
+
+// -----
+
+// CHECK-LABEL: test_greater_equal_f32
+func.func @test_greater_equal_f32() -> (tensor<4xi1>, tensor<4xi1>, tensor<4xi1>) {
+  // CHECK: %[[VAL_0:.+]] = "tosa.const"() <{value = dense<[true, true, false, false]> : tensor<4xi1>}> : () -> tensor<4xi1>
+  // CHECK: %[[VAL_1:.+]] = "tosa.const"() <{value = dense<[true, false, true, true]> : tensor<4xi1>}> : () -> tensor<4xi1>
+  // CHECK: %[[VAL_2:.+]] = "tosa.const"() <{value = dense<[true, true, true, false]> : tensor<4xi1>}> : () -> tensor<4xi1>
+  // CHECK: return %[[VAL_0]], %[[VAL_1]], %[[VAL_2]]
+  %cst1 = "tosa.const"() {value = dense<[1.5, 2.4, -0.4, -1.2]> : tensor<4xf32>} : () -> tensor<4xf32>
+  %splat = "tosa.const"() {value = dense<1.5> : tensor<4xf32>} : () -> tensor<4xf32>
+  %cst2 = "tosa.const"() {value = dense<[1.4, 2.4, -0.5, -1.1]> : tensor<4xf32>} : () -> tensor<4xf32>
+  %x = tosa.greater_equal %cst1, %splat : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
+  %y = tosa.greater_equal %splat, %cst1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
+  %z = tosa.greater_equal %cst1, %cst2 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
+  return %x, %y, %z : tensor<4xi1>, tensor<4xi1>, tensor<4xi1>
+}
+
+// -----
+
+// CHECK-LABEL: test_greater_equal_i32
+func.func @test_greater_equal_i32() -> (tensor<4xi1>, tensor<4xi1>, tensor<4xi1>) {
+  // CHECK: %[[VAL_0:.+]] = "tosa.const"() <{value = dense<[false, true, false, false]> : tensor<4xi1>}> : () -> tensor<4xi1>
+  // CHECK: %[[VAL_1:.+]] = "tosa.const"() <{value = dense<[true, false, true, true]> : tensor<4xi1>}> : () -> tensor<4xi1>
+  // CHECK: %[[VAL_2:.+]] = "tosa.const"() <{value = dense<[true, true, true, false]> : tensor<4xi1>}> : () -> tensor<4xi1>
+  // CHECK: return %[[VAL_0]], %[[VAL_1]], %[[VAL_2]]
+  %cst1 = "tosa.const"() {value = dense<[15, 24, -4, -12]> : tensor<4xi32>} : () -> tensor<4xi32>
+  %splat = "tosa.const"() {value = dense<16> : tensor<4xi32>} : () -> tensor<4xi32>
+  %cst2 = "tosa.const"() {value = dense<[14, 24, -5, -11]> : tensor<4xi32>} : () -> tensor<4xi32>
+  %x = tosa.greater_equal %cst1, %splat : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
+  %y = tosa.greater_equal %splat, %cst1 : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
+  %z = tosa.greater_equal %cst1, %cst2 : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
+  return %x, %y, %z : tensor<4xi1>, tensor<4xi1>, tensor<4xi1>
+}
+
+// -----
+
+// CHECK-LABEL: test_equal_f32
+func.func @test_equal_f32() -> (tensor<4xi1>, tensor<4xi1>, tensor<4xi1>) {
+  // CHECK: %[[VAL_0:.+]] = "tosa.const"() <{value = dense<[true, false, false, false]> : tensor<4xi1>}> : () -> tensor<4xi1>
+  // CHECK: %[[VAL_1:.+]] = "tosa.const"() <{value = dense<[false, true, false, true]> : tensor<4xi1>}> : () -> tensor<4xi1>
+  // CHECK: return %[[VAL_0]], %[[VAL_0]], %[[VAL_1]]
+  %cst1 = "tosa.const"() {value = dense<[1.5, 2.4, -0.4, -1.2]> : tensor<4xf32>} : () -> tensor<4xf32>
+  %splat = "tosa.const"() {value = dense<1.5> : tensor<4xf32>} : () -> tensor<4xf32>
+  %cst2 = "tosa.const"() {value = dense<[1.4, 2.4, -0.5, -1.2]> : tensor<4xf32>} : () -> tensor<4xf32>
+  %x = tosa.equal %cst1, %splat : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
+  %y = tosa.equal %splat, %cst1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
+  %z = tosa.equal %cst1, %cst2 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
+  return %x, %y, %z : tensor<4xi1>, tensor<4xi1>, tensor<4xi1>
+}
+
+// -----
+
+// CHECK-LABEL: test_equal_i32
+func.func @test_equal_i32() -> (tensor<4xi1>, tensor<4xi1>, tensor<4xi1>) {
+  // CHECK: %[[VAL_0:.+]] = "tosa.const"() <{value = dense<[true, false, false, false]> : tensor<4xi1>}> : () -> tensor<4xi1>
+  // CHECK: %[[VAL_1:.+]] = "tosa.const"() <{value = dense<[false, true, false, true]> : tensor<4xi1>}> : () -> tensor<4xi1>
+  // CHECK: return %[[VAL_0]], %[[VAL_0]], %[[VAL_1]]
+  %cst1 = "tosa.const"() {value = dense<[15, 24, -4, -12]> : tensor<4xi32>} : () -> tensor<4xi32>
+  %splat = "tosa.const"() {value = dense<15> : tensor<4xi32>} : () -> tensor<4xi32>
+  %cst2 = "tosa.const"() {value = dense<[14, 24, -5, -12]> : tensor<4xi32>} : () -> tensor<4xi32>
+  %x = tosa.equal %cst1, %splat : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
+  %y = tosa.equal %splat, %cst1 : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
+  %z = tosa.equal %cst1, %cst2 : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
+  return %x, %y, %z : tensor<4xi1>, tensor<4xi1>, tensor<4xi1>
 }



More information about the Mlir-commits mailing list