[Mlir-commits] [mlir] [mlir][tosa] Enhance folder for Tosa binary operators (PR #128059)
Tai Ly
llvmlistbot at llvm.org
Thu Feb 20 12:14:18 PST 2025
https://github.com/Tai78641 created https://github.com/llvm/llvm-project/pull/128059
This enhances the folder for TOSA binary operators to support non-splat constant attributes for the following ops:
- mul
- add
- sub
- greater
- greater_equal
- equal
>From c4bbd39bd588bc2038d435aff4d85c89712ef3cf Mon Sep 17 00:00:00 2001
From: Tai Ly <tai.ly at arm.com>
Date: Tue, 4 Jun 2024 18:01:33 +0000
Subject: [PATCH] [mlir][tosa] Enhance folder for Tosa binary operators
This enhances the folder for TOSA binary operators to support non-splat constant attributes for
the following ops:
- mul
- add
- sub
- greater
- greater_equal
- equal
Signed-off-by: Tai Ly <tai.ly at arm.com>
Change-Id: I3198a808988a71b5894d8f7c410b407340564c38
---
.../Dialect/Tosa/IR/TosaCanonicalizations.cpp | 180 ++++++++++++---
mlir/test/Dialect/Tosa/constant-op-fold.mlir | 7 +-
mlir/test/Dialect/Tosa/constant_folding.mlir | 218 +++++++++++++++++-
3 files changed, 365 insertions(+), 40 deletions(-)
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 9bfc2aae1d6a5..2c6c6e2ed284c 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -563,15 +563,21 @@ void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
// Operator Folders.
//===----------------------------------------------------------------------===//
-template <typename IntFolder, typename FloatFolder>
+template <typename IntFolder, typename FloatFolder, typename FloatResultAPType>
DenseElementsAttr binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
RankedTensorType returnTy) {
- if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) {
- auto lETy = llvm::cast<ShapedType>(lhs.getType()).getElementType();
- auto rETy = llvm::cast<ShapedType>(rhs.getType()).getElementType();
- if (lETy != rETy)
- return {};
+ if (!rhs || !lhs)
+ return {};
+
+ auto lETy = llvm::cast<ShapedType>(lhs.getType()).getElementType();
+ auto rETy = llvm::cast<ShapedType>(rhs.getType()).getElementType();
+ if (lETy != rETy)
+ return {};
+
+ if (!lETy.isIntOrFloat())
+ return {};
+ if (rhs.isSplat() && lhs.isSplat()) {
if (llvm::isa<IntegerType>(lETy)) {
APInt l = lhs.getSplatValue<APInt>();
APInt r = rhs.getSplatValue<APInt>();
@@ -587,9 +593,54 @@ DenseElementsAttr binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
}
}
+ if (llvm::isa<IntegerType>(lETy)) {
+ auto lvalues = lhs.getValues<APInt>();
+ auto rvalues = rhs.getValues<APInt>();
+ if (lvalues.size() != rvalues.size()) {
+ return {};
+ }
+ SmallVector<APInt> results;
+ for (const auto &[l, r] : llvm::zip(lvalues, rvalues)) {
+ auto result = IntFolder()(l, r);
+ results.push_back(result);
+ }
+ return DenseElementsAttr::get(returnTy, results);
+ }
+
+ if (llvm::isa<FloatType>(lETy)) {
+ auto lvalues = lhs.getValues<APFloat>();
+ auto rvalues = rhs.getValues<APFloat>();
+ if (lvalues.size() != rvalues.size()) {
+ return {};
+ }
+ // FloatFolder() may return either APFloat or APInt (comparison functions)
+ SmallVector<FloatResultAPType> results;
+ for (const auto &[l, r] : llvm::zip(lvalues, rvalues)) {
+ auto result = FloatFolder()(l, r);
+ results.push_back(result);
+ }
+ return DenseElementsAttr::get(returnTy, results);
+ }
+
return {};
}
+template <typename IntFolder, typename FloatFolder>
+DenseElementsAttr comparisonBinaryFolder(DenseElementsAttr lhs,
+ DenseElementsAttr rhs,
+ RankedTensorType returnTy) {
+ // comparison FloatFolder() functions return APInt values
+ return binaryFolder<IntFolder, FloatFolder, APInt>(lhs, rhs, returnTy);
+}
+
+template <typename IntFolder, typename FloatFolder>
+DenseElementsAttr arithmeticBinaryFolder(DenseElementsAttr lhs,
+ DenseElementsAttr rhs,
+ RankedTensorType returnTy) {
+ // arithmetic FloatFolder() functions return APFloat values
+ return binaryFolder<IntFolder, FloatFolder, APFloat>(lhs, rhs, returnTy);
+}
+
static bool isSplatZero(Type elemType, DenseElementsAttr val) {
if (llvm::isa<FloatType>(elemType))
return val && val.isSplat() && val.getSplatValue<APFloat>().isZero();
@@ -636,8 +687,8 @@ OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
if (!lhsAttr || !rhsAttr)
return {};
- return binaryFolder<std::plus<APInt>, std::plus<APFloat>>(lhsAttr, rhsAttr,
- resultTy);
+ return arithmeticBinaryFolder<std::plus<APInt>, std::plus<APFloat>>(
+ lhsAttr, rhsAttr, resultTy);
}
OpFoldResult ArgMaxOp::fold(FoldAdaptor adaptor) {
@@ -693,32 +744,96 @@ OpFoldResult IntDivOp::fold(FoldAdaptor adaptor) {
}
namespace {
+
+// calculate lhs * rhs >> shift according to TOSA Spec
+// return nullopt if result is not in range of int32_t when shift > 0
+std::optional<APInt> mulInt(APInt lhs, APInt rhs, int32_t shift,
+ unsigned bitwidth) {
+ APInt result = lhs.sext(64) * rhs.sext(64);
+
+ if (shift > 0) {
+ auto round = APInt(64, 1) << (shift - 1);
+ result += round;
+ result.ashrInPlace(shift);
+ // REQUIRE(product >= minimum_s<i32_t>() && product <= maximum_s<i32_t>())
+ if (!(result.getSExtValue() >= INT32_MIN &&
+ result.getSExtValue() <= INT32_MAX)) {
+ // REQUIRE failed
+ return std::nullopt;
+ }
+ }
+
+ return result.trunc(bitwidth);
+}
+
DenseElementsAttr mulBinaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
RankedTensorType ty, int32_t shift) {
- if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) {
- if (llvm::isa<IntegerType>(ty.getElementType())) {
- APInt l = lhs.getSplatValue<APInt>();
- APInt r = rhs.getSplatValue<APInt>();
+ if (!lhs || !rhs)
+ return {};
+
+ // REQUIRE(0 <= shift && shift <= 63);
+ if (!(0 <= shift && shift <= 63))
+ return {};
- if (shift == 0) {
- return DenseElementsAttr::get(ty, l * r);
+ auto elementType = ty.getElementType();
+ if (!elementType.isIntOrFloat())
+ return {};
+
+ unsigned bitwidth = elementType.getIntOrFloatBitWidth();
+ // REQUIRE(in_t == int32_t || shift == 0);
+ if (!((llvm::isa<IntegerType>(elementType) && bitwidth == 32) || shift == 0))
+ return {};
+
+ if (rhs.isSplat() && lhs.isSplat()) {
+ if (llvm::isa<IntegerType>(elementType)) {
+ auto l = lhs.getSplatValue<APInt>();
+ auto r = rhs.getSplatValue<APInt>();
+
+ if (auto result = mulInt(l, r, shift, bitwidth)) {
+ return DenseElementsAttr::get(ty, result.value());
}
+ // mulInt failed
+ return {};
+ }
- auto bitwidth = ty.getElementType().getIntOrFloatBitWidth();
- l = l.sext(bitwidth * 2);
- r = r.sext(bitwidth * 2);
+ if (llvm::isa<FloatType>(elementType)) {
+ auto l = lhs.getSplatValue<APFloat>();
+ auto r = rhs.getSplatValue<APFloat>();
auto result = l * r;
- result.lshrInPlace(shift);
- result = result.trunc(bitwidth);
return DenseElementsAttr::get(ty, result);
}
+ }
+
+ if (llvm::isa<IntegerType>(elementType)) {
+ auto lvalues = lhs.getValues<APInt>();
+ auto rvalues = rhs.getValues<APInt>();
+ if (lvalues.size() != rvalues.size()) {
+ return {};
+ }
+ SmallVector<APInt> results;
+ for (const auto &[l, r] : llvm::zip(lvalues, rvalues)) {
+ if (auto result = mulInt(l, r, shift, bitwidth)) {
+ results.push_back(result.value());
+ continue;
+ }
+ // mulInt failed
+ return {};
+ }
+ return DenseElementsAttr::get(ty, results);
+ }
- if (llvm::isa<FloatType>(ty.getElementType())) {
- APFloat l = lhs.getSplatValue<APFloat>();
- APFloat r = rhs.getSplatValue<APFloat>();
- APFloat result = l * r;
- return DenseElementsAttr::get(ty, result);
+ if (llvm::isa<FloatType>(elementType)) {
+ auto lvalues = lhs.getValues<APFloat>();
+ auto rvalues = rhs.getValues<APFloat>();
+ if (lvalues.size() != rvalues.size()) {
+ return {};
}
+ SmallVector<APFloat> results;
+ for (const auto &[l, r] : llvm::zip(lvalues, rvalues)) {
+ auto result = l * r;
+ results.push_back(result);
+ }
+ return DenseElementsAttr::get(ty, results);
}
return {};
@@ -793,8 +908,8 @@ OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
if (!lhsAttr || !rhsAttr)
return {};
- return binaryFolder<std::minus<APInt>, std::minus<APFloat>>(lhsAttr, rhsAttr,
- resultTy);
+ return arithmeticBinaryFolder<std::minus<APInt>, std::minus<APFloat>>(
+ lhsAttr, rhsAttr, resultTy);
}
namespace {
@@ -835,7 +950,8 @@ OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
if (!lhsAttr || !rhsAttr)
return {};
- return binaryFolder<APIntFoldGreater, ComparisonFold<std::greater<APFloat>>>(
+ return comparisonBinaryFolder<APIntFoldGreater,
+ ComparisonFold<std::greater<APFloat>>>(
lhsAttr, rhsAttr, resultTy);
}
@@ -849,8 +965,8 @@ OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
if (!lhsAttr || !rhsAttr)
return {};
- return binaryFolder<APIntFoldGreaterEqual,
- ComparisonFold<std::greater_equal<APFloat>>>(
+ return comparisonBinaryFolder<APIntFoldGreaterEqual,
+ ComparisonFold<std::greater_equal<APFloat>>>(
lhsAttr, rhsAttr, resultTy);
}
@@ -874,9 +990,9 @@ OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
if (!lhsAttr || !rhsAttr)
return {};
- return binaryFolder<ComparisonFold<std::equal_to<APInt>>,
- ComparisonFold<std::equal_to<APFloat>>>(lhsAttr, rhsAttr,
- resultTy);
+ return comparisonBinaryFolder<ComparisonFold<std::equal_to<APInt>>,
+ ComparisonFold<std::equal_to<APFloat>>>(
+ lhsAttr, rhsAttr, resultTy);
}
OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
diff --git a/mlir/test/Dialect/Tosa/constant-op-fold.mlir b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
index e6fb741df9598..5aab368fa044d 100644
--- a/mlir/test/Dialect/Tosa/constant-op-fold.mlir
+++ b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
@@ -1092,11 +1092,8 @@ func.func @reduce_sum_constant() -> tensor<1x3xi32> {
func.func @reduce_sum_constant() -> tensor<1x3xi32> {
// CHECK-LABEL: func.func @reduce_sum_constant() -> tensor<1x3xi32> {
- // CHECK: %[[VAL_0:.*]] = "tosa.const"() <{value = dense<{{\[\[}}1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>}> : () -> tensor<2x3xi32>
- // CHECK: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<{{\[\[}}1, 2, 3], [4, 5, 7]]> : tensor<2x3xi32>}> : () -> tensor<2x3xi32>
- // CHECK: %[[VAL_2:.*]] = tosa.add %[[VAL_0]], %[[VAL_1]] : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
- // CHECK: %[[VAL_3:.*]] = tosa.reduce_sum %[[VAL_2]] {axis = 0 : i32} : (tensor<2x3xi32>) -> tensor<1x3xi32>
- // CHECK: return %[[VAL_3]] : tensor<1x3xi32>
+ // CHECK: %[[K:.*]] = "tosa.const"() <{value = dense<{{\[\[}}10, 14, 19]]> : tensor<1x3xi32>}> : () -> tensor<1x3xi32>
+ // CHECK: return %[[K]] : tensor<1x3xi32>
%arg0 = "tosa.const"() <{value = dense<[[1,2,3], [4,5,6]]> : tensor<2x3xi32>}> : () -> tensor<2x3xi32>
%arg1 = "tosa.const"() <{value = dense<[[1,2,3], [4,5,7]]> : tensor<2x3xi32>}> : () -> tensor<2x3xi32>
%arg2 = tosa.add %arg0, %arg1 : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
diff --git a/mlir/test/Dialect/Tosa/constant_folding.mlir b/mlir/test/Dialect/Tosa/constant_folding.mlir
index 3ff3121348fca..fee1ce7793b12 100644
--- a/mlir/test/Dialect/Tosa/constant_folding.mlir
+++ b/mlir/test/Dialect/Tosa/constant_folding.mlir
@@ -1,4 +1,6 @@
-// RUN: mlir-opt --test-constant-fold %s | FileCheck %s
+// RUN: mlir-opt --split-input-file --canonicalize --test-constant-fold %s | FileCheck %s
+
+// -----
// CHECK-LABEL: func @test_const
func.func @test_const(%arg0 : index) -> tensor<4xi32> {
@@ -7,6 +9,8 @@ func.func @test_const(%arg0 : index) -> tensor<4xi32> {
return %0 : tensor<4xi32>
}
+// -----
+
// CHECK-LABEL: func @test_const_i64
func.func @test_const_i64(%arg0 : index) -> tensor<4xi64> {
// CHECK: tosa.const
@@ -14,10 +18,218 @@ func.func @test_const_i64(%arg0 : index) -> tensor<4xi64> {
return %0 : tensor<4xi64>
}
+// -----
+
// CHECK-LABEL: func @try_fold_equal_with_unranked_tensor
-func.func @try_fold_equal_with_unranked_tensor(%arg0: tensor<4xi32>, %arg1: tensor<1xi32>) {
+func.func @try_fold_equal_with_unranked_tensor(%arg0: tensor<4xi32>, %arg1: tensor<1xi32>) -> tensor<*xi1> {
// CHECK: tosa.equal
// CHECK-NEXT: return
%0 = tosa.equal %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi1>
- return
+ return %0 : tensor<*xi1>
+}
+
+// -----
+
+// CHECK-LABEL: test_mul_i32
+func.func @test_mul_i32() -> tensor<4xi32> {
+ // CHECK: %[[VAL:.+]] = "tosa.const"() <{value = dense<[9, 36, 36, 81]> : tensor<4xi32>}>
+ // CHECK: return %[[VAL]]
+ %lhs = "tosa.const"() {value = dense<[1, 2, -2, -3]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %rhs = "tosa.const"() {value = dense<3> : tensor<4xi32>} : () -> tensor<4xi32>
+ %shift = "tosa.const"() { value = dense<0> : tensor<1xi8> } : () -> tensor<1xi8>
+ %x = tosa.mul %lhs, %rhs, %shift : (tensor<4xi32>, tensor<4xi32>, tensor<1xi8>) -> tensor<4xi32>
+ %y = tosa.mul %rhs, %lhs, %shift : (tensor<4xi32>, tensor<4xi32>, tensor<1xi8>) -> tensor<4xi32>
+ %result = tosa.mul %x, %y, %shift : (tensor<4xi32>, tensor<4xi32>, tensor<1xi8>) -> tensor<4xi32>
+
+ return %result : tensor<4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: test_mul_i32_shift
+func.func @test_mul_i32_shift() -> tensor<4xi32> {
+ // CHECK: %[[VAL:.+]] = "tosa.const"() <{value = dense<[2550, 8100, 2, 2025]> : tensor<4xi32>}>
+ // CHECK: return %[[VAL]]
+ %lhs = "tosa.const"() {value = dense<[135, 240, -4, -120]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %rhs = "tosa.const"() {value = dense<3> : tensor<4xi32>} : () -> tensor<4xi32>
+ %shift = "tosa.const"() { value = dense<2> : tensor<1xi8> } : () -> tensor<1xi8>
+ %x = tosa.mul %lhs, %rhs, %shift : (tensor<4xi32>, tensor<4xi32>, tensor<1xi8>) -> tensor<4xi32>
+ %y = tosa.mul %rhs, %lhs, %shift : (tensor<4xi32>, tensor<4xi32>, tensor<1xi8>) -> tensor<4xi32>
+ %result = tosa.mul %x, %y, %shift : (tensor<4xi32>, tensor<4xi32>, tensor<1xi8>) -> tensor<4xi32>
+ return %result : tensor<4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: test_mul_f32
+func.func @test_mul_f32() -> tensor<4xf32> {
+ // CHECK: %[[VAL:.+]] = "tosa.const"() <{value = dense<[2.304000e+01, 58.9824028, 1.6384002, 14.7456007]> : tensor<4xf32>}>
+ // CHECK: return %[[VAL]]
+ %lhs = "tosa.const"() {value = dense<[1.5, 2.4, -0.4, -1.2]> : tensor<4xf32>} : () -> tensor<4xf32>
+ %rhs = "tosa.const"() {value = dense<3.2> : tensor<4xf32>} : () -> tensor<4xf32>
+ %shift = "tosa.const"() { value = dense<0> : tensor<1xi8> } : () -> tensor<1xi8>
+ %x = tosa.mul %lhs, %rhs, %shift : (tensor<4xf32>, tensor<4xf32>, tensor<1xi8>) -> tensor<4xf32>
+ %y = tosa.mul %rhs, %lhs, %shift : (tensor<4xf32>, tensor<4xf32>, tensor<1xi8>) -> tensor<4xf32>
+ %result = tosa.mul %x, %y, %shift : (tensor<4xf32>, tensor<4xf32>, tensor<1xi8>) -> tensor<4xf32>
+ return %result : tensor<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_add_f32
+func.func @test_add_f32() -> tensor<4xf32> {
+ // CHECK: %[[VAL:.+]] = "tosa.const"() <{value = dense<[7.500000e+00, 9.300000e+00, 3.69999981, 2.100000e+00]> : tensor<4xf32>}>
+ // CHECK: return %[[VAL]]
+ %cst = "tosa.const"() {value = dense<[1.5, 2.4, -0.4, -1.2]> : tensor<4xf32>} : () -> tensor<4xf32>
+ %splat1 = "tosa.const"() {value = dense<3.2> : tensor<4xf32>} : () -> tensor<4xf32>
+ %splat2 = "tosa.const"() {value = dense<1.3> : tensor<4xf32>} : () -> tensor<4xf32>
+ %x = tosa.add %cst, %splat1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+ %y = tosa.add %splat2, %cst : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+ %result = tosa.add %x, %y : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+ return %result : tensor<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_add_i32
+func.func @test_add_i32() -> tensor<4xi32> {
+ // CHECK: %[[VAL:.+]] = "tosa.const"() <{value = dense<[75, 93, 37, 21]> : tensor<4xi32>}>
+ // CHECK: return %[[VAL]]
+ %cst = "tosa.const"() {value = dense<[15, 24, -4, -12]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %splat1 = "tosa.const"() {value = dense<32> : tensor<4xi32>} : () -> tensor<4xi32>
+ %splat2 = "tosa.const"() {value = dense<13> : tensor<4xi32>} : () -> tensor<4xi32>
+ %x = tosa.add %cst, %splat1 : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
+ %y = tosa.add %splat2, %cst : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
+ %result = tosa.add %x, %y : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
+ return %result : tensor<4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: test_sub_f32
+func.func @test_sub_f32() -> tensor<4xf32> {
+ // CHECK: %[[VAL:.+]] = "tosa.const"() <{value = dense<[-1.500000e+00, 0.300000191, -5.300000e+00, -6.900000e+00]> : tensor<4xf32>}>
+ // CHECK: return %[[VAL]]
+ %cst = "tosa.const"() {value = dense<[1.5, 2.4, -0.4, -1.2]> : tensor<4xf32>} : () -> tensor<4xf32>
+ %splat1 = "tosa.const"() {value = dense<3.2> : tensor<4xf32>} : () -> tensor<4xf32>
+ %splat2 = "tosa.const"() {value = dense<1.3> : tensor<4xf32>} : () -> tensor<4xf32>
+ %x = tosa.sub %cst, %splat1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+ %y = tosa.sub %splat2, %cst : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+ %result = tosa.sub %x, %y : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+ return %result : tensor<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_sub_i32
+func.func @test_sub_i32() -> tensor<4xi32> {
+ // CHECK: %[[VAL:.+]] = "tosa.const"() <{value = dense<[-15, 3, -53, -69]> : tensor<4xi32>}>
+ // CHECK: return %[[VAL]]
+ %cst = "tosa.const"() {value = dense<[15, 24, -4, -12]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %splat1 = "tosa.const"() {value = dense<32> : tensor<4xi32>} : () -> tensor<4xi32>
+ %splat2 = "tosa.const"() {value = dense<13> : tensor<4xi32>} : () -> tensor<4xi32>
+ %x = tosa.sub %cst, %splat1 : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
+ %y = tosa.sub %splat2, %cst : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
+ %result = tosa.sub %x, %y : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
+ return %result : tensor<4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: test_greater_f32
+func.func @test_greater_f32() -> (tensor<4xi1>, tensor<4xi1>, tensor<4xi1>) {
+ // CHECK: %[[VAL_0:.+]] = "tosa.const"() <{value = dense<[false, true, false, false]> : tensor<4xi1>}> : () -> tensor<4xi1>
+ // CHECK: %[[VAL_1:.+]] = "tosa.const"() <{value = dense<[false, false, true, true]> : tensor<4xi1>}> : () -> tensor<4xi1>
+ // CHECK: %[[VAL_2:.+]] = "tosa.const"() <{value = dense<[false, true, true, false]> : tensor<4xi1>}> : () -> tensor<4xi1>
+ // CHECK: return %[[VAL_0]], %[[VAL_1]], %[[VAL_2]]
+ %cst1 = "tosa.const"() {value = dense<[1.5, 2.4, -0.4, -1.2]> : tensor<4xf32>} : () -> tensor<4xf32>
+ %splat = "tosa.const"() {value = dense<1.5> : tensor<4xf32>} : () -> tensor<4xf32>
+ %cst2 = "tosa.const"() {value = dense<[1.7, 2.3, -0.5, -1.1]> : tensor<4xf32>} : () -> tensor<4xf32>
+ %x = tosa.greater %cst1, %splat : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
+ %y = tosa.greater %splat, %cst1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
+ %z = tosa.greater %cst1, %cst2 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
+ return %x, %y, %z : tensor<4xi1>, tensor<4xi1>, tensor<4xi1>
+}
+
+// -----
+
+// CHECK-LABEL: test_greater_i32
+func.func @test_greater_i32() -> (tensor<4xi1>, tensor<4xi1>, tensor<4xi1>) {
+ // CHECK: %[[VAL_0:.+]] = "tosa.const"() <{value = dense<[false, true, false, false]> : tensor<4xi1>}> : () -> tensor<4xi1>
+ // CHECK: %[[VAL_1:.+]] = "tosa.const"() <{value = dense<[false, false, true, true]> : tensor<4xi1>}> : () -> tensor<4xi1>
+ // CHECK: %[[VAL_2:.+]] = "tosa.const"() <{value = dense<[false, true, true, false]> : tensor<4xi1>}> : () -> tensor<4xi1>
+ // CHECK: return %[[VAL_0]], %[[VAL_1]], %[[VAL_2]]
+ %cst1 = "tosa.const"() {value = dense<[15, 24, -4, -12]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %cst2 = "tosa.const"() {value = dense<[17, 23, -5, -11]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %splat = "tosa.const"() {value = dense<15> : tensor<4xi32>} : () -> tensor<4xi32>
+ %x = tosa.greater %cst1, %splat : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
+ %y = tosa.greater %splat, %cst1 : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
+ %z = tosa.greater %cst1, %cst2 : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
+ return %x, %y, %z : tensor<4xi1>, tensor<4xi1>, tensor<4xi1>
+}
+
+// -----
+
+// CHECK-LABEL: test_greater_equal_f32
+func.func @test_greater_equal_f32() -> (tensor<4xi1>, tensor<4xi1>, tensor<4xi1>) {
+ // CHECK: %[[VAL_0:.+]] = "tosa.const"() <{value = dense<[true, true, false, false]> : tensor<4xi1>}> : () -> tensor<4xi1>
+ // CHECK: %[[VAL_1:.+]] = "tosa.const"() <{value = dense<[true, false, true, true]> : tensor<4xi1>}> : () -> tensor<4xi1>
+ // CHECK: %[[VAL_2:.+]] = "tosa.const"() <{value = dense<[true, true, true, false]> : tensor<4xi1>}> : () -> tensor<4xi1>
+ // CHECK: return %[[VAL_0]], %[[VAL_1]], %[[VAL_2]]
+ %cst1 = "tosa.const"() {value = dense<[1.5, 2.4, -0.4, -1.2]> : tensor<4xf32>} : () -> tensor<4xf32>
+ %splat = "tosa.const"() {value = dense<1.5> : tensor<4xf32>} : () -> tensor<4xf32>
+ %cst2 = "tosa.const"() {value = dense<[1.4, 2.4, -0.5, -1.1]> : tensor<4xf32>} : () -> tensor<4xf32>
+ %x = tosa.greater_equal %cst1, %splat : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
+ %y = tosa.greater_equal %splat, %cst1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
+ %z = tosa.greater_equal %cst1, %cst2 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
+ return %x, %y, %z : tensor<4xi1>, tensor<4xi1>, tensor<4xi1>
+}
+
+// -----
+
+// CHECK-LABEL: test_greater_equal_i32
+func.func @test_greater_equal_i32() -> (tensor<4xi1>, tensor<4xi1>, tensor<4xi1>) {
+ // CHECK: %[[VAL_0:.+]] = "tosa.const"() <{value = dense<[false, true, false, false]> : tensor<4xi1>}> : () -> tensor<4xi1>
+ // CHECK: %[[VAL_1:.+]] = "tosa.const"() <{value = dense<[true, false, true, true]> : tensor<4xi1>}> : () -> tensor<4xi1>
+ // CHECK: %[[VAL_2:.+]] = "tosa.const"() <{value = dense<[true, true, true, false]> : tensor<4xi1>}> : () -> tensor<4xi1>
+ // CHECK: return %[[VAL_0]], %[[VAL_1]], %[[VAL_2]]
+ %cst1 = "tosa.const"() {value = dense<[15, 24, -4, -12]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %splat = "tosa.const"() {value = dense<16> : tensor<4xi32>} : () -> tensor<4xi32>
+ %cst2 = "tosa.const"() {value = dense<[14, 24, -5, -11]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %x = tosa.greater_equal %cst1, %splat : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
+ %y = tosa.greater_equal %splat, %cst1 : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
+ %z = tosa.greater_equal %cst1, %cst2 : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
+ return %x, %y, %z : tensor<4xi1>, tensor<4xi1>, tensor<4xi1>
+}
+
+// -----
+
+// CHECK-LABEL: test_equal_f32
+func.func @test_equal_f32() -> (tensor<4xi1>, tensor<4xi1>, tensor<4xi1>) {
+ // CHECK: %[[VAL_0:.+]] = "tosa.const"() <{value = dense<[true, false, false, false]> : tensor<4xi1>}> : () -> tensor<4xi1>
+ // CHECK: %[[VAL_1:.+]] = "tosa.const"() <{value = dense<[false, true, false, true]> : tensor<4xi1>}> : () -> tensor<4xi1>
+ // CHECK: return %[[VAL_0]], %[[VAL_0]], %[[VAL_1]]
+ %cst1 = "tosa.const"() {value = dense<[1.5, 2.4, -0.4, -1.2]> : tensor<4xf32>} : () -> tensor<4xf32>
+ %splat = "tosa.const"() {value = dense<1.5> : tensor<4xf32>} : () -> tensor<4xf32>
+ %cst2 = "tosa.const"() {value = dense<[1.4, 2.4, -0.5, -1.2]> : tensor<4xf32>} : () -> tensor<4xf32>
+ %x = tosa.equal %cst1, %splat : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
+ %y = tosa.equal %splat, %cst1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
+ %z = tosa.equal %cst1, %cst2 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
+ return %x, %y, %z : tensor<4xi1>, tensor<4xi1>, tensor<4xi1>
+}
+
+// -----
+
+// CHECK-LABEL: test_equal_i32
+func.func @test_equal_i32() -> (tensor<4xi1>, tensor<4xi1>, tensor<4xi1>) {
+ // CHECK: %[[VAL_0:.+]] = "tosa.const"() <{value = dense<[true, false, false, false]> : tensor<4xi1>}> : () -> tensor<4xi1>
+ // CHECK: %[[VAL_1:.+]] = "tosa.const"() <{value = dense<[false, true, false, true]> : tensor<4xi1>}> : () -> tensor<4xi1>
+ // CHECK: return %[[VAL_0]], %[[VAL_0]], %[[VAL_1]]
+ %cst1 = "tosa.const"() {value = dense<[15, 24, -4, -12]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %splat = "tosa.const"() {value = dense<15> : tensor<4xi32>} : () -> tensor<4xi32>
+ %cst2 = "tosa.const"() {value = dense<[14, 24, -5, -12]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %x = tosa.equal %cst1, %splat : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
+ %y = tosa.equal %splat, %cst1 : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
+ %z = tosa.equal %cst1, %cst2 : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
+ return %x, %y, %z : tensor<4xi1>, tensor<4xi1>, tensor<4xi1>
}
More information about the Mlir-commits
mailing list