[Mlir-commits] [mlir] [mlir][tosa] Enhance folder for Tosa binary operators (PR #128059)
Tai Ly
llvmlistbot at llvm.org
Thu Mar 6 19:31:42 PST 2025
https://github.com/Tai78641 updated https://github.com/llvm/llvm-project/pull/128059
>From 558851bf7f964cedc160e89413c2e7b6f025d3e6 Mon Sep 17 00:00:00 2001
From: Tai Ly <tai.ly at arm.com>
Date: Tue, 4 Jun 2024 18:01:33 +0000
Subject: [PATCH] [mlir][tosa] Enhance folder for Tosa binary operators
This enhances the folder for TOSA binary operators to support non-splat constant attributes for
the following ops:
- mul
- add
- sub
- greater
- greater_equal
- equal
---
.../Dialect/Tosa/IR/TosaCanonicalizations.cpp | 185 ++++++++++++---
mlir/test/Dialect/Tosa/constant-op-fold.mlir | 7 +-
mlir/test/Dialect/Tosa/constant_folding.mlir | 218 +++++++++++++++++-
3 files changed, 370 insertions(+), 40 deletions(-)
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 3e99c1f717d09..1be5697955020 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -501,15 +501,21 @@ void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
// Operator Folders.
//===----------------------------------------------------------------------===//
-template <typename IntFolder, typename FloatFolder>
+template <typename IntFolder, typename FloatFolder, typename FloatResultAPType>
DenseElementsAttr binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
RankedTensorType returnTy) {
- if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) {
- auto lETy = llvm::cast<ShapedType>(lhs.getType()).getElementType();
- auto rETy = llvm::cast<ShapedType>(rhs.getType()).getElementType();
- if (lETy != rETy)
- return {};
+ if (!rhs || !lhs)
+ return {};
+
+ auto lETy = llvm::cast<ShapedType>(lhs.getType()).getElementType();
+ auto rETy = llvm::cast<ShapedType>(rhs.getType()).getElementType();
+ if (lETy != rETy)
+ return {};
+
+ if (!lETy.isIntOrFloat())
+ return {};
+ if (rhs.isSplat() && lhs.isSplat()) {
if (llvm::isa<IntegerType>(lETy)) {
APInt l = lhs.getSplatValue<APInt>();
APInt r = rhs.getSplatValue<APInt>();
@@ -525,9 +531,59 @@ DenseElementsAttr binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
}
}
+ auto lhsCount = lhs.getNumElements();
+ auto rhsCount = rhs.getNumElements();
+ if (lhsCount != rhsCount)
+ return {};
+
+ // to prevent long compile time, skip if too many elements
+ if (lhsCount > 128)
+ return {};
+
+ if (llvm::isa<IntegerType>(lETy)) {
+ auto lvalues = lhs.getValues<APInt>();
+ auto rvalues = rhs.getValues<APInt>();
+ SmallVector<APInt> results;
+ IntFolder intFolder{};
+ for (const auto &[l, r] : llvm::zip(lvalues, rvalues)) {
+ auto result = intFolder(l, r);
+ results.push_back(result);
+ }
+ return DenseElementsAttr::get(returnTy, results);
+ }
+
+ if (llvm::isa<FloatType>(lETy)) {
+ auto lvalues = lhs.getValues<APFloat>();
+ auto rvalues = rhs.getValues<APFloat>();
+ // FloatFolder() may return either APFloat or APInt (comparison functions)
+ SmallVector<FloatResultAPType> results;
+ FloatFolder floatFolder{};
+ for (const auto &[l, r] : llvm::zip(lvalues, rvalues)) {
+ auto result = floatFolder(l, r);
+ results.push_back(result);
+ }
+ return DenseElementsAttr::get(returnTy, results);
+ }
+
return {};
}
+template <typename IntFolder, typename FloatFolder>
+DenseElementsAttr comparisonBinaryFolder(DenseElementsAttr lhs,
+ DenseElementsAttr rhs,
+ RankedTensorType returnTy) {
+ // comparison FloatFolder() functions return APInt values
+ return binaryFolder<IntFolder, FloatFolder, APInt>(lhs, rhs, returnTy);
+}
+
+template <typename IntFolder, typename FloatFolder>
+DenseElementsAttr arithmeticBinaryFolder(DenseElementsAttr lhs,
+ DenseElementsAttr rhs,
+ RankedTensorType returnTy) {
+ // arithmetic FloatFolder() functions return APFloat values
+ return binaryFolder<IntFolder, FloatFolder, APFloat>(lhs, rhs, returnTy);
+}
+
static bool isSplatZero(Type elemType, DenseElementsAttr val) {
if (llvm::isa<FloatType>(elemType))
return val && val.isSplat() && val.getSplatValue<APFloat>().isZero();
@@ -574,8 +630,8 @@ OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
if (!lhsAttr || !rhsAttr)
return {};
- return binaryFolder<std::plus<APInt>, std::plus<APFloat>>(lhsAttr, rhsAttr,
- resultTy);
+ return arithmeticBinaryFolder<std::plus<APInt>, std::plus<APFloat>>(
+ lhsAttr, rhsAttr, resultTy);
}
OpFoldResult ArgMaxOp::fold(FoldAdaptor adaptor) {
@@ -632,32 +688,96 @@ OpFoldResult IntDivOp::fold(FoldAdaptor adaptor) {
}
namespace {
+
+// calculate lhs * rhs >> shift according to TOSA Spec
+// return nullopt if result is not in range of int32_t when shift > 0
+std::optional<APInt> mulInt(APInt lhs, APInt rhs, int32_t shift,
+ unsigned bitwidth) {
+ APInt result = lhs.sext(64) * rhs.sext(64);
+
+ if (shift > 0) {
+ auto round = APInt(64, 1) << (shift - 1);
+ result += round;
+ result.ashrInPlace(shift);
+ // REQUIRE(product >= minimum_s<i32_t>() && product <= maximum_s<i32_t>())
+ if (!(result.getSExtValue() >= INT32_MIN &&
+ result.getSExtValue() <= INT32_MAX)) {
+ // REQUIRE failed
+ return std::nullopt;
+ }
+ }
+
+ return result.trunc(bitwidth);
+}
+
DenseElementsAttr mulBinaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
RankedTensorType ty, int32_t shift) {
- if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) {
- if (llvm::isa<IntegerType>(ty.getElementType())) {
- APInt l = lhs.getSplatValue<APInt>();
- APInt r = rhs.getSplatValue<APInt>();
+ if (!lhs || !rhs)
+ return {};
+
+ // REQUIRE(0 <= shift && shift <= 63);
+ if (!(0 <= shift && shift <= 63))
+ return {};
+
+ auto elementType = ty.getElementType();
+ if (!elementType.isIntOrFloat())
+ return {};
- if (shift == 0) {
- return DenseElementsAttr::get(ty, l * r);
+ unsigned bitwidth = elementType.getIntOrFloatBitWidth();
+ // REQUIRE(in_t == int32_t || shift == 0);
+ if (!((llvm::isa<IntegerType>(elementType) && bitwidth == 32) || shift == 0))
+ return {};
+
+ if (rhs.isSplat() && lhs.isSplat()) {
+ if (llvm::isa<IntegerType>(elementType)) {
+ auto l = lhs.getSplatValue<APInt>();
+ auto r = rhs.getSplatValue<APInt>();
+
+ if (auto result = mulInt(l, r, shift, bitwidth)) {
+ return DenseElementsAttr::get(ty, result.value());
}
+ // mulInt failed
+ return {};
+ }
- auto bitwidth = ty.getElementType().getIntOrFloatBitWidth();
- l = l.sext(bitwidth * 2);
- r = r.sext(bitwidth * 2);
+ if (llvm::isa<FloatType>(elementType)) {
+ auto l = lhs.getSplatValue<APFloat>();
+ auto r = rhs.getSplatValue<APFloat>();
auto result = l * r;
- result.lshrInPlace(shift);
- result = result.trunc(bitwidth);
return DenseElementsAttr::get(ty, result);
}
+ }
- if (llvm::isa<FloatType>(ty.getElementType())) {
- APFloat l = lhs.getSplatValue<APFloat>();
- APFloat r = rhs.getSplatValue<APFloat>();
- APFloat result = l * r;
- return DenseElementsAttr::get(ty, result);
+ if (llvm::isa<IntegerType>(elementType)) {
+ auto lvalues = lhs.getValues<APInt>();
+ auto rvalues = rhs.getValues<APInt>();
+ if (lvalues.size() != rvalues.size()) {
+ return {};
+ }
+ SmallVector<APInt> results;
+ for (const auto &[l, r] : llvm::zip(lvalues, rvalues)) {
+ if (auto result = mulInt(l, r, shift, bitwidth)) {
+ results.push_back(result.value());
+ continue;
+ }
+ // mulInt failed
+ return {};
+ }
+ return DenseElementsAttr::get(ty, results);
+ }
+
+ if (llvm::isa<FloatType>(elementType)) {
+ auto lvalues = lhs.getValues<APFloat>();
+ auto rvalues = rhs.getValues<APFloat>();
+ if (lvalues.size() != rvalues.size()) {
+ return {};
}
+ SmallVector<APFloat> results;
+ for (const auto &[l, r] : llvm::zip(lvalues, rvalues)) {
+ auto result = l * r;
+ results.push_back(result);
+ }
+ return DenseElementsAttr::get(ty, results);
}
return {};
@@ -732,8 +852,8 @@ OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
if (!lhsAttr || !rhsAttr)
return {};
- return binaryFolder<std::minus<APInt>, std::minus<APFloat>>(lhsAttr, rhsAttr,
- resultTy);
+ return arithmeticBinaryFolder<std::minus<APInt>, std::minus<APFloat>>(
+ lhsAttr, rhsAttr, resultTy);
}
namespace {
@@ -774,7 +894,8 @@ OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
if (!lhsAttr || !rhsAttr)
return {};
- return binaryFolder<APIntFoldGreater, ComparisonFold<std::greater<APFloat>>>(
+ return comparisonBinaryFolder<APIntFoldGreater,
+ ComparisonFold<std::greater<APFloat>>>(
lhsAttr, rhsAttr, resultTy);
}
@@ -788,8 +909,8 @@ OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
if (!lhsAttr || !rhsAttr)
return {};
- return binaryFolder<APIntFoldGreaterEqual,
- ComparisonFold<std::greater_equal<APFloat>>>(
+ return comparisonBinaryFolder<APIntFoldGreaterEqual,
+ ComparisonFold<std::greater_equal<APFloat>>>(
lhsAttr, rhsAttr, resultTy);
}
@@ -813,9 +934,9 @@ OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
if (!lhsAttr || !rhsAttr)
return {};
- return binaryFolder<ComparisonFold<std::equal_to<APInt>>,
- ComparisonFold<std::equal_to<APFloat>>>(lhsAttr, rhsAttr,
- resultTy);
+ return comparisonBinaryFolder<ComparisonFold<std::equal_to<APInt>>,
+ ComparisonFold<std::equal_to<APFloat>>>(
+ lhsAttr, rhsAttr, resultTy);
}
OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
diff --git a/mlir/test/Dialect/Tosa/constant-op-fold.mlir b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
index 8ac1e177ae4d4..2f396519b4420 100644
--- a/mlir/test/Dialect/Tosa/constant-op-fold.mlir
+++ b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
@@ -1082,11 +1082,8 @@ func.func @reduce_sum_constant() -> tensor<1x3xi32> {
func.func @reduce_sum_constant() -> tensor<1x3xi32> {
// CHECK-LABEL: func.func @reduce_sum_constant() -> tensor<1x3xi32> {
- // CHECK: %[[VAL_0:.*]] = "tosa.const"() <{values = dense<{{\[\[}}1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>}> : () -> tensor<2x3xi32>
- // CHECK: %[[VAL_1:.*]] = "tosa.const"() <{values = dense<{{\[\[}}1, 2, 3], [4, 5, 7]]> : tensor<2x3xi32>}> : () -> tensor<2x3xi32>
- // CHECK: %[[VAL_2:.*]] = tosa.add %[[VAL_0]], %[[VAL_1]] : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
- // CHECK: %[[VAL_3:.*]] = tosa.reduce_sum %[[VAL_2]] {axis = 0 : i32} : (tensor<2x3xi32>) -> tensor<1x3xi32>
- // CHECK: return %[[VAL_3]] : tensor<1x3xi32>
+ // CHECK: %[[K:.*]] = "tosa.const"() <{values = dense<{{\[\[}}10, 14, 19]]> : tensor<1x3xi32>}> : () -> tensor<1x3xi32>
+ // CHECK: return %[[K]] : tensor<1x3xi32>
%arg0 = "tosa.const"() <{values = dense<[[1,2,3], [4,5,6]]> : tensor<2x3xi32>}> : () -> tensor<2x3xi32>
%arg1 = "tosa.const"() <{values = dense<[[1,2,3], [4,5,7]]> : tensor<2x3xi32>}> : () -> tensor<2x3xi32>
%arg2 = tosa.add %arg0, %arg1 : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
diff --git a/mlir/test/Dialect/Tosa/constant_folding.mlir b/mlir/test/Dialect/Tosa/constant_folding.mlir
index 9b6ccdb54c107..8ae8af75c3856 100644
--- a/mlir/test/Dialect/Tosa/constant_folding.mlir
+++ b/mlir/test/Dialect/Tosa/constant_folding.mlir
@@ -1,4 +1,6 @@
-// RUN: mlir-opt --test-constant-fold %s | FileCheck %s
+// RUN: mlir-opt --split-input-file --canonicalize --test-constant-fold %s | FileCheck %s
+
+// -----
// CHECK-LABEL: func @test_const
func.func @test_const(%arg0 : index) -> tensor<4xi32> {
@@ -7,6 +9,8 @@ func.func @test_const(%arg0 : index) -> tensor<4xi32> {
return %0 : tensor<4xi32>
}
+// -----
+
// CHECK-LABEL: func @test_const_i64
func.func @test_const_i64(%arg0 : index) -> tensor<4xi64> {
// CHECK: tosa.const
@@ -14,10 +18,218 @@ func.func @test_const_i64(%arg0 : index) -> tensor<4xi64> {
return %0 : tensor<4xi64>
}
+// -----
+
// CHECK-LABEL: func @try_fold_equal_with_unranked_tensor
-func.func @try_fold_equal_with_unranked_tensor(%arg0: tensor<4xi32>, %arg1: tensor<1xi32>) {
+func.func @try_fold_equal_with_unranked_tensor(%arg0: tensor<4xi32>, %arg1: tensor<1xi32>) -> tensor<*xi1> {
// CHECK: tosa.equal
// CHECK-NEXT: return
%0 = tosa.equal %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi1>
- return
+ return %0 : tensor<*xi1>
+}
+
+// -----
+
+// CHECK-LABEL: test_mul_i32
+func.func @test_mul_i32() -> tensor<4xi32> {
+ // CHECK: %[[VAL:.+]] = "tosa.const"() <{values = dense<[9, 36, 36, 81]> : tensor<4xi32>}>
+ // CHECK: return %[[VAL]]
+ %lhs = "tosa.const"() {values = dense<[1, 2, -2, -3]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %rhs = "tosa.const"() {values = dense<3> : tensor<4xi32>} : () -> tensor<4xi32>
+ %shift = "tosa.const"() { values = dense<0> : tensor<1xi8> } : () -> tensor<1xi8>
+ %x = tosa.mul %lhs, %rhs, %shift : (tensor<4xi32>, tensor<4xi32>, tensor<1xi8>) -> tensor<4xi32>
+ %y = tosa.mul %rhs, %lhs, %shift : (tensor<4xi32>, tensor<4xi32>, tensor<1xi8>) -> tensor<4xi32>
+ %result = tosa.mul %x, %y, %shift : (tensor<4xi32>, tensor<4xi32>, tensor<1xi8>) -> tensor<4xi32>
+
+ return %result : tensor<4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: test_mul_i32_shift
+func.func @test_mul_i32_shift() -> tensor<4xi32> {
+ // CHECK: %[[VAL:.+]] = "tosa.const"() <{values = dense<[2550, 8100, 2, 2025]> : tensor<4xi32>}>
+ // CHECK: return %[[VAL]]
+ %lhs = "tosa.const"() {values = dense<[135, 240, -4, -120]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %rhs = "tosa.const"() {values = dense<3> : tensor<4xi32>} : () -> tensor<4xi32>
+ %shift = "tosa.const"() { values = dense<2> : tensor<1xi8> } : () -> tensor<1xi8>
+ %x = tosa.mul %lhs, %rhs, %shift : (tensor<4xi32>, tensor<4xi32>, tensor<1xi8>) -> tensor<4xi32>
+ %y = tosa.mul %rhs, %lhs, %shift : (tensor<4xi32>, tensor<4xi32>, tensor<1xi8>) -> tensor<4xi32>
+ %result = tosa.mul %x, %y, %shift : (tensor<4xi32>, tensor<4xi32>, tensor<1xi8>) -> tensor<4xi32>
+ return %result : tensor<4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: test_mul_f32
+func.func @test_mul_f32() -> tensor<4xf32> {
+ // CHECK: %[[VAL:.+]] = "tosa.const"() <{values = dense<[2.304000e+01, 58.9824028, 1.6384002, 14.7456007]> : tensor<4xf32>}>
+ // CHECK: return %[[VAL]]
+ %lhs = "tosa.const"() {values = dense<[1.5, 2.4, -0.4, -1.2]> : tensor<4xf32>} : () -> tensor<4xf32>
+ %rhs = "tosa.const"() {values = dense<3.2> : tensor<4xf32>} : () -> tensor<4xf32>
+ %shift = "tosa.const"() { values = dense<0> : tensor<1xi8> } : () -> tensor<1xi8>
+ %x = tosa.mul %lhs, %rhs, %shift : (tensor<4xf32>, tensor<4xf32>, tensor<1xi8>) -> tensor<4xf32>
+ %y = tosa.mul %rhs, %lhs, %shift : (tensor<4xf32>, tensor<4xf32>, tensor<1xi8>) -> tensor<4xf32>
+ %result = tosa.mul %x, %y, %shift : (tensor<4xf32>, tensor<4xf32>, tensor<1xi8>) -> tensor<4xf32>
+ return %result : tensor<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_add_f32
+func.func @test_add_f32() -> tensor<4xf32> {
+ // CHECK: %[[VAL:.+]] = "tosa.const"() <{values = dense<[7.500000e+00, 9.300000e+00, 3.69999981, 2.100000e+00]> : tensor<4xf32>}>
+ // CHECK: return %[[VAL]]
+ %cst = "tosa.const"() {values = dense<[1.5, 2.4, -0.4, -1.2]> : tensor<4xf32>} : () -> tensor<4xf32>
+ %splat1 = "tosa.const"() {values = dense<3.2> : tensor<4xf32>} : () -> tensor<4xf32>
+ %splat2 = "tosa.const"() {values = dense<1.3> : tensor<4xf32>} : () -> tensor<4xf32>
+ %x = tosa.add %cst, %splat1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+ %y = tosa.add %splat2, %cst : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+ %result = tosa.add %x, %y : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+ return %result : tensor<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_add_i32
+func.func @test_add_i32() -> tensor<4xi32> {
+ // CHECK: %[[VAL:.+]] = "tosa.const"() <{values = dense<[75, 93, 37, 21]> : tensor<4xi32>}>
+ // CHECK: return %[[VAL]]
+ %cst = "tosa.const"() {values = dense<[15, 24, -4, -12]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %splat1 = "tosa.const"() {values = dense<32> : tensor<4xi32>} : () -> tensor<4xi32>
+ %splat2 = "tosa.const"() {values = dense<13> : tensor<4xi32>} : () -> tensor<4xi32>
+ %x = tosa.add %cst, %splat1 : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
+ %y = tosa.add %splat2, %cst : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
+ %result = tosa.add %x, %y : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
+ return %result : tensor<4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: test_sub_f32
+func.func @test_sub_f32() -> tensor<4xf32> {
+ // CHECK: %[[VAL:.+]] = "tosa.const"() <{values = dense<[-1.500000e+00, 0.300000191, -5.300000e+00, -6.900000e+00]> : tensor<4xf32>}>
+ // CHECK: return %[[VAL]]
+ %cst = "tosa.const"() {values = dense<[1.5, 2.4, -0.4, -1.2]> : tensor<4xf32>} : () -> tensor<4xf32>
+ %splat1 = "tosa.const"() {values = dense<3.2> : tensor<4xf32>} : () -> tensor<4xf32>
+ %splat2 = "tosa.const"() {values = dense<1.3> : tensor<4xf32>} : () -> tensor<4xf32>
+ %x = tosa.sub %cst, %splat1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+ %y = tosa.sub %splat2, %cst : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+ %result = tosa.sub %x, %y : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+ return %result : tensor<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_sub_i32
+func.func @test_sub_i32() -> tensor<4xi32> {
+ // CHECK: %[[VAL:.+]] = "tosa.const"() <{values = dense<[-15, 3, -53, -69]> : tensor<4xi32>}>
+ // CHECK: return %[[VAL]]
+ %cst = "tosa.const"() {values = dense<[15, 24, -4, -12]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %splat1 = "tosa.const"() {values = dense<32> : tensor<4xi32>} : () -> tensor<4xi32>
+ %splat2 = "tosa.const"() {values = dense<13> : tensor<4xi32>} : () -> tensor<4xi32>
+ %x = tosa.sub %cst, %splat1 : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
+ %y = tosa.sub %splat2, %cst : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
+ %result = tosa.sub %x, %y : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
+ return %result : tensor<4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: test_greater_f32
+func.func @test_greater_f32() -> (tensor<4xi1>, tensor<4xi1>, tensor<4xi1>) {
+ // CHECK: %[[VAL_0:.+]] = "tosa.const"() <{values = dense<[false, true, false, false]> : tensor<4xi1>}> : () -> tensor<4xi1>
+ // CHECK: %[[VAL_1:.+]] = "tosa.const"() <{values = dense<[false, false, true, true]> : tensor<4xi1>}> : () -> tensor<4xi1>
+ // CHECK: %[[VAL_2:.+]] = "tosa.const"() <{values = dense<[false, true, true, false]> : tensor<4xi1>}> : () -> tensor<4xi1>
+ // CHECK: return %[[VAL_0]], %[[VAL_1]], %[[VAL_2]]
+ %cst1 = "tosa.const"() {values = dense<[1.5, 2.4, -0.4, -1.2]> : tensor<4xf32>} : () -> tensor<4xf32>
+ %splat = "tosa.const"() {values = dense<1.5> : tensor<4xf32>} : () -> tensor<4xf32>
+ %cst2 = "tosa.const"() {values = dense<[1.7, 2.3, -0.5, -1.1]> : tensor<4xf32>} : () -> tensor<4xf32>
+ %x = tosa.greater %cst1, %splat : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
+ %y = tosa.greater %splat, %cst1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
+ %z = tosa.greater %cst1, %cst2 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
+ return %x, %y, %z : tensor<4xi1>, tensor<4xi1>, tensor<4xi1>
+}
+
+// -----
+
+// CHECK-LABEL: test_greater_i32
+func.func @test_greater_i32() -> (tensor<4xi1>, tensor<4xi1>, tensor<4xi1>) {
+ // CHECK: %[[VAL_0:.+]] = "tosa.const"() <{values = dense<[false, true, false, false]> : tensor<4xi1>}> : () -> tensor<4xi1>
+ // CHECK: %[[VAL_1:.+]] = "tosa.const"() <{values = dense<[false, false, true, true]> : tensor<4xi1>}> : () -> tensor<4xi1>
+ // CHECK: %[[VAL_2:.+]] = "tosa.const"() <{values = dense<[false, true, true, false]> : tensor<4xi1>}> : () -> tensor<4xi1>
+ // CHECK: return %[[VAL_0]], %[[VAL_1]], %[[VAL_2]]
+ %cst1 = "tosa.const"() {values = dense<[15, 24, -4, -12]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %cst2 = "tosa.const"() {values = dense<[17, 23, -5, -11]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %splat = "tosa.const"() {values = dense<15> : tensor<4xi32>} : () -> tensor<4xi32>
+ %x = tosa.greater %cst1, %splat : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
+ %y = tosa.greater %splat, %cst1 : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
+ %z = tosa.greater %cst1, %cst2 : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
+ return %x, %y, %z : tensor<4xi1>, tensor<4xi1>, tensor<4xi1>
+}
+
+// -----
+
+// CHECK-LABEL: test_greater_equal_f32
+func.func @test_greater_equal_f32() -> (tensor<4xi1>, tensor<4xi1>, tensor<4xi1>) {
+ // CHECK: %[[VAL_0:.+]] = "tosa.const"() <{values = dense<[true, true, false, false]> : tensor<4xi1>}> : () -> tensor<4xi1>
+ // CHECK: %[[VAL_1:.+]] = "tosa.const"() <{values = dense<[true, false, true, true]> : tensor<4xi1>}> : () -> tensor<4xi1>
+ // CHECK: %[[VAL_2:.+]] = "tosa.const"() <{values = dense<[true, true, true, false]> : tensor<4xi1>}> : () -> tensor<4xi1>
+ // CHECK: return %[[VAL_0]], %[[VAL_1]], %[[VAL_2]]
+ %cst1 = "tosa.const"() {values = dense<[1.5, 2.4, -0.4, -1.2]> : tensor<4xf32>} : () -> tensor<4xf32>
+ %splat = "tosa.const"() {values = dense<1.5> : tensor<4xf32>} : () -> tensor<4xf32>
+ %cst2 = "tosa.const"() {values = dense<[1.4, 2.4, -0.5, -1.1]> : tensor<4xf32>} : () -> tensor<4xf32>
+ %x = tosa.greater_equal %cst1, %splat : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
+ %y = tosa.greater_equal %splat, %cst1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
+ %z = tosa.greater_equal %cst1, %cst2 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
+ return %x, %y, %z : tensor<4xi1>, tensor<4xi1>, tensor<4xi1>
+}
+
+// -----
+
+// CHECK-LABEL: test_greater_equal_i32
+func.func @test_greater_equal_i32() -> (tensor<4xi1>, tensor<4xi1>, tensor<4xi1>) {
+ // CHECK: %[[VAL_0:.+]] = "tosa.const"() <{values = dense<[false, true, false, false]> : tensor<4xi1>}> : () -> tensor<4xi1>
+ // CHECK: %[[VAL_1:.+]] = "tosa.const"() <{values = dense<[true, false, true, true]> : tensor<4xi1>}> : () -> tensor<4xi1>
+ // CHECK: %[[VAL_2:.+]] = "tosa.const"() <{values = dense<[true, true, true, false]> : tensor<4xi1>}> : () -> tensor<4xi1>
+ // CHECK: return %[[VAL_0]], %[[VAL_1]], %[[VAL_2]]
+ %cst1 = "tosa.const"() {values = dense<[15, 24, -4, -12]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %splat = "tosa.const"() {values = dense<16> : tensor<4xi32>} : () -> tensor<4xi32>
+ %cst2 = "tosa.const"() {values = dense<[14, 24, -5, -11]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %x = tosa.greater_equal %cst1, %splat : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
+ %y = tosa.greater_equal %splat, %cst1 : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
+ %z = tosa.greater_equal %cst1, %cst2 : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
+ return %x, %y, %z : tensor<4xi1>, tensor<4xi1>, tensor<4xi1>
+}
+
+// -----
+
+// CHECK-LABEL: test_equal_f32
+func.func @test_equal_f32() -> (tensor<4xi1>, tensor<4xi1>, tensor<4xi1>) {
+ // CHECK: %[[VAL_0:.+]] = "tosa.const"() <{values = dense<[true, false, false, false]> : tensor<4xi1>}> : () -> tensor<4xi1>
+ // CHECK: %[[VAL_1:.+]] = "tosa.const"() <{values = dense<[false, true, false, true]> : tensor<4xi1>}> : () -> tensor<4xi1>
+ // CHECK: return %[[VAL_0]], %[[VAL_0]], %[[VAL_1]]
+ %cst1 = "tosa.const"() {values = dense<[1.5, 2.4, -0.4, -1.2]> : tensor<4xf32>} : () -> tensor<4xf32>
+ %splat = "tosa.const"() {values = dense<1.5> : tensor<4xf32>} : () -> tensor<4xf32>
+ %cst2 = "tosa.const"() {values = dense<[1.4, 2.4, -0.5, -1.2]> : tensor<4xf32>} : () -> tensor<4xf32>
+ %x = tosa.equal %cst1, %splat : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
+ %y = tosa.equal %splat, %cst1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
+ %z = tosa.equal %cst1, %cst2 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
+ return %x, %y, %z : tensor<4xi1>, tensor<4xi1>, tensor<4xi1>
+}
+
+// -----
+
+// CHECK-LABEL: test_equal_i32
+func.func @test_equal_i32() -> (tensor<4xi1>, tensor<4xi1>, tensor<4xi1>) {
+ // CHECK: %[[VAL_0:.+]] = "tosa.const"() <{values = dense<[true, false, false, false]> : tensor<4xi1>}> : () -> tensor<4xi1>
+ // CHECK: %[[VAL_1:.+]] = "tosa.const"() <{values = dense<[false, true, false, true]> : tensor<4xi1>}> : () -> tensor<4xi1>
+ // CHECK: return %[[VAL_0]], %[[VAL_0]], %[[VAL_1]]
+ %cst1 = "tosa.const"() {values = dense<[15, 24, -4, -12]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %splat = "tosa.const"() {values = dense<15> : tensor<4xi32>} : () -> tensor<4xi32>
+ %cst2 = "tosa.const"() {values = dense<[14, 24, -5, -12]> : tensor<4xi32>} : () -> tensor<4xi32>
+ %x = tosa.equal %cst1, %splat : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
+ %y = tosa.equal %splat, %cst1 : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
+ %z = tosa.equal %cst1, %cst2 : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
+ return %x, %y, %z : tensor<4xi1>, tensor<4xi1>, tensor<4xi1>
}
More information about the Mlir-commits
mailing list