[Mlir-commits] [mlir] [mlir][tosa][tosa-to-linalg] Add NaN Mode Lowering (PR #125668)
Jack Frankland
llvmlistbot at llvm.org
Thu Feb 20 08:18:48 PST 2025
https://github.com/FranklandJack updated https://github.com/llvm/llvm-project/pull/125668
>From 7c747c02d1c18110a60c7b72919001b50280b0c3 Mon Sep 17 00:00:00 2001
From: Jack Frankland <jack.frankland at arm.com>
Date: Mon, 3 Feb 2025 18:36:09 +0000
Subject: [PATCH] [mlir][tosa][tosa-to-linalg] Add NaN Mode Lowering
Add support for NaN propagation lowering in the `tosa-to-linalg` and
`tosa-to-linalg-named` conversions by conditionally checking for NaN in
the case of ignore semantics and materializing the appropriate select
operations. Note that the default behaviour of "propagate" matches
that of the arith dialect and so in that case we can avoid creating the
checks altogether.
Add appropriate lit tests including negative tests which check the
various comparisons and selects are materialized as appropriate.
This affects the following TOSA operators:
* arg_max
* max_pool_2d
* clamp
* reduce_max
* reduce_min
* maximum
* minimum
Signed-off-by: Jack Frankland <jack.frankland at arm.com>
---
.../Conversion/TosaToLinalg/TosaToLinalg.cpp | 211 +++++++++++++++++-
.../TosaToLinalg/TosaToLinalgNamed.cpp | 39 +++-
.../TosaToLinalg/tosa-to-linalg-named.mlir | 22 ++
.../TosaToLinalg/tosa-to-linalg.mlir | 192 ++++++++++++++++
4 files changed, 450 insertions(+), 14 deletions(-)
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 7b70b3ab8afc9..8c5a989d8d075 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -32,10 +32,47 @@
#include "llvm/ADT/Sequence.h"
#include <numeric>
+#include <type_traits>
using namespace mlir;
using namespace mlir::tosa;
+// Helper function to materialize the semantically correct compare and select
+// operations for a binary operation with a specific NaN propagation mode.
+//
+// In the case of "PROPAGATE" semantics no compare and selection is required and
+// this function simply returns `result` unchanged.
+//
+// Otherwise ("IGNORE" semantics) this function materializes a comparison of
+// each operand of the op against itself, which will return true for any NaN
+// argument, and then selects between the non-NaN operation argument and the
+// calculated result based on whether the lhs or rhs is NaN or not. In pseudo
+// code:
+//
+// binary<op>(lhs, rhs):
+// result = op(lhs, rhs)
+// if lhs == NaN return rhs
+// if rhs == NaN return lhs
+// return result
+template <typename OpTy>
+static Value
+materializeBinaryNanCheckIfRequired(OpTy op, PatternRewriter &rewriter,
+ Value lhs, Value rhs, Value result) {
+ auto nanMode = op.getNanMode();
+ if (nanMode == "PROPAGATE")
+ return result;
+
+ // Unordered comparison of a value against itself is true only if it is NaN.
+ Value lhsIsNaN = rewriter.create<arith::CmpFOp>(
+ op.getLoc(), arith::CmpFPredicate::UNO, lhs, lhs);
+ Value rhsIsNaN = rewriter.create<arith::CmpFOp>(
+ op.getLoc(), arith::CmpFPredicate::UNO, rhs, rhs);
+ Value rhsOrResult =
+ rewriter.create<arith::SelectOp>(op.getLoc(), lhsIsNaN, rhs, result);
+ return rewriter.create<arith::SelectOp>(op.getLoc(), rhsIsNaN, lhs,
+ rhsOrResult);
+}
+
template <typename T>
static arith::ConstantOp
createConstFromIntAttribute(Operation *op, const std::string &attrName,
@@ -367,7 +404,9 @@ static Value createLinalgBodyCalculationForElementwiseOp(
// tosa::MaximumOp
if (isa<tosa::MaximumOp>(op) && isa<FloatType>(elementTy)) {
- return rewriter.create<arith::MaximumFOp>(loc, args[0], args[1]);
+ auto max = rewriter.create<arith::MaximumFOp>(loc, args[0], args[1]);
+ return materializeBinaryNanCheckIfRequired(llvm::cast<tosa::MaximumOp>(op),
+ rewriter, args[0], args[1], max);
}
if (isa<tosa::MaximumOp>(op) && elementTy.isSignlessInteger()) {
@@ -376,7 +415,9 @@ static Value createLinalgBodyCalculationForElementwiseOp(
// tosa::MinimumOp
if (isa<tosa::MinimumOp>(op) && isa<FloatType>(elementTy)) {
- return rewriter.create<arith::MinimumFOp>(loc, args[0], args[1]);
+ auto min = rewriter.create<arith::MinimumFOp>(loc, args[0], args[1]);
+ return materializeBinaryNanCheckIfRequired(llvm::cast<tosa::MinimumOp>(op),
+ rewriter, args[0], args[1], min);
}
if (isa<tosa::MinimumOp>(op) && elementTy.isSignlessInteger()) {
@@ -404,7 +445,31 @@ static Value createLinalgBodyCalculationForElementwiseOp(
loc, elementTy, rewriter.getFloatAttr(elementTy, minApf));
auto max = rewriter.create<arith::ConstantOp>(
loc, elementTy, rewriter.getFloatAttr(elementTy, maxApf));
- return clampFloatHelper(loc, args[0], min, max, rewriter);
+ auto result = clampFloatHelper(loc, args[0], min, max, rewriter);
+
+ auto clampOp = llvm::cast<tosa::ClampOp>(op);
+ const auto nanMode = clampOp.getNanMode();
+ // In the case of "PROPAGATE" semantics no compare and selection is
+ // required.
+ if (nanMode == "PROPAGATE")
+ return result;
+
+ // In the case of "IGNORE" semantics materialize a comparison
+ // of the clamp operand against itself which will return true for a NaN
+ // argument and then select between the lower clamp bound ("min") and the
+ // calculated result based on whether the argument is NaN or not. In pseudo
+ // code:
+ //
+ // clamp(x, min, max):
+ // result = clampFloatHelper(x, min, max)
+ // return min if x == NaN else result
+
+ // Unordered comparison of NaN against itself will always return true.
+ Value isNaN = rewriter.create<arith::CmpFOp>(
+ op->getLoc(), arith::CmpFPredicate::UNO, args[0], args[0]);
+ // TOSA specifies that in "ignore" NaN mode the result is "min" if the input
+ // is NaN.
+ return rewriter.create<arith::SelectOp>(op->getLoc(), isNaN, min, result);
}
if (isa<tosa::ClampOp>(op) && isa<IntegerType>(elementTy)) {
@@ -1078,7 +1143,8 @@ static Value createLinalgBodyCalculationForReduceOp(Operation *op,
// Performs the match and rewrite for reduction operations. This includes
// declaring a correctly sized initial value, and the linalg.generic operation
// that reduces across the specified axis.
-static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
+template <typename OpTy>
+static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
PatternRewriter &rewriter) {
auto loc = op->getLoc();
auto inputTy = cast<ShapedType>(op->getOperand(0).getType());
@@ -1096,6 +1162,9 @@ static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
}
}
+ SmallVector<Value> inputs, outputs;
+ inputs.push_back(input);
+
// First fill the output buffer with the init value.
auto emptyTensor =
rewriter
@@ -1113,26 +1182,127 @@ static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
.create<linalg::FillOp>(loc, ValueRange{fillValue},
ValueRange{emptyTensor})
.result();
+ outputs.push_back(filledTensor);
+
+ bool isNanIgnoreMode = false;
+ if constexpr (std::is_same_v<OpTy, tosa::ReduceMinOp> ||
+ std::is_same_v<OpTy, tosa::ReduceMaxOp>) {
+ if (op.getNanMode() == "IGNORE") {
+ isNanIgnoreMode = true;
+ // Because the TOSA spec requires the result be NaN iff all elements in
+ // the reduction are NaN we can't simply perform a compare and select.
+ // Additionally we have to keep track of whether we've seen any non-NaN
+ // values and then do a final select based on this predicate.
+ auto trueAttr = rewriter.getBoolAttr(true);
+ auto trueValue = rewriter.create<arith::ConstantOp>(loc, trueAttr);
+ auto emptyBoolTensor =
+ rewriter
+ .create<tensor::EmptyOp>(loc, reduceShape, trueValue.getType(),
+ dynDims)
+ .getResult();
+ auto allResultsNaNTensor =
+ rewriter
+ .create<linalg::FillOp>(loc, ValueRange{trueValue},
+ ValueRange{emptyBoolTensor})
+ .result();
+ // Note that because the linalg::ReduceOp has two variadic arguments
+ // (inputs and outputs) and it has the SameVariadicOperandSize trait we
+ // need to have the same number of inputs and outputs.
+ //
+ // The second input isn't actually used anywhere since the value used to
+ // update the NaN flag is calculated inside the body of the reduction and
+ // then used to update an out value.
+ // In order to satisfy type constraints we just pass another copy of the
+ // input here.
+ inputs.push_back(input);
+ outputs.push_back(allResultsNaNTensor);
+ }
+ }
bool didEncounterError = false;
- auto linalgOp = rewriter.create<linalg::ReduceOp>(
- loc, input, filledTensor, axis,
+ linalg::LinalgOp linalgOp = rewriter.create<linalg::ReduceOp>(
+ loc, inputs, outputs, axis,
[&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) {
+ std::array<Value, 2> binaryArgs{
+ blockArgs[0], isNanIgnoreMode ? blockArgs[2] : blockArgs[1]};
auto result = createLinalgBodyCalculationForReduceOp(
- op, blockArgs, elementTy, rewriter);
+ op, binaryArgs, elementTy, rewriter);
if (result)
didEncounterError = true;
- nestedBuilder.create<linalg::YieldOp>(loc, result);
+ SmallVector<Value> resultsToYield;
+ if (isNanIgnoreMode) {
+ auto inputValue = blockArgs[0];
+ auto initialValue = blockArgs[2];
+ auto oldAllResultsNanFlagValue = blockArgs[3];
+
+ // Unordered comparison of NaN against itself will always return true.
+ Value isNaN = nestedBuilder.create<arith::CmpFOp>(
+ op->getLoc(), arith::CmpFPredicate::UNO, inputValue, inputValue);
+ // If we've encountered a NaN, take the non-NaN value.
+ auto selectOp = nestedBuilder.create<arith::SelectOp>(
+ op->getLoc(), isNaN, initialValue, result);
+ // Update the flag which keeps track of whether we have seen a non-NaN
+ // value.
+ auto newAllResultsNanFlagValue = nestedBuilder.create<arith::AndIOp>(
+ op->getLoc(), oldAllResultsNanFlagValue, isNaN);
+ resultsToYield.push_back(selectOp);
+ resultsToYield.push_back(newAllResultsNanFlagValue);
+ } else {
+ resultsToYield.push_back(result);
+ }
+ nestedBuilder.create<linalg::YieldOp>(loc, resultsToYield);
});
if (!didEncounterError)
return rewriter.notifyMatchFailure(
op, "unable to create linalg.generic body for reduce op");
+ if (isNanIgnoreMode) {
+ // Materialize a check to see whether we encountered any non-NaN values, if
+ // we didn't we need to select a tensor of NaNs since the result will just
+ // be the initial identity value propagated through all the compares and
+ // selects inside the reduction.
+
+ // Create a tensor full of NaNs.
+ auto nanValueAttr = rewriter.getFloatAttr(
+ elementTy,
+ APFloat::getNaN(cast<FloatType>(elementTy).getFloatSemantics(), false));
+ auto nanValue = rewriter.create<arith::ConstantOp>(loc, nanValueAttr);
+ auto emptyNanTensor =
+ rewriter
+ .create<tensor::EmptyOp>(loc, reduceShape,
+ resultTy.getElementType(), dynDims)
+ .getResult();
+ auto nanFilledTensor =
+ rewriter
+ .create<linalg::FillOp>(loc, ValueRange{nanValue},
+ ValueRange{emptyNanTensor})
+ .result();
+
+ // Create an empty tensor, no need to fill this since it will be
+ // overwritten by the select.
+ auto finalEmptyTensor =
+ rewriter
+ .create<tensor::EmptyOp>(loc, reduceShape,
+ resultTy.getElementType(), dynDims)
+ .getResult();
+
+ // Do a selection between the tensors akin to:
+ // result = NaN if "all results NaN" else result.
+ SmallVector<Value> ins, outs;
+ ins.push_back(linalgOp->getOpResult(1));
+ ins.push_back(nanFilledTensor);
+ ins.push_back(linalgOp->getResult(0));
+ outs.push_back(finalEmptyTensor);
+ auto linalgSelect =
+ rewriter.create<linalg::SelectOp>(op->getLoc(), ins, outs);
+ linalgOp = linalgSelect;
+ }
+
SmallVector<ReassociationExprs, 4> reassociationMap;
uint64_t expandInputRank =
- cast<ShapedType>(linalgOp.getResults()[0].getType()).getRank();
+ cast<ShapedType>(linalgOp->getResults()[0].getType()).getRank();
reassociationMap.resize(expandInputRank);
for (uint64_t i = 0; i < expandInputRank; i++) {
@@ -1151,7 +1321,7 @@ static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
// not have access to such information. This matters when handling dynamically
// sized tensors.
rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
- op, resultTy, linalgOp.getResults()[0], reassociationMap);
+ op, resultTy, linalgOp->getResults()[0], reassociationMap);
return success();
}
@@ -2097,6 +2267,27 @@ class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
nestedLoc, predicate, newValue, oldValue);
auto resultIndex = rewriter.create<arith::SelectOp>(
nestedLoc, predicate, newIndex, oldIndex);
+
+ // Check if we need to materialize compare and select for the given
+ // NaN propagation mode.
+
+ // "PROPAGATE" matches the default NaN propagation mode of the arith
+ // dialect so no compare and select is required.
+ //
+ // In the case "IGNORE" we check if the current argument is NaN and
+ // select the old index and value otherwise take the updated index and
+ // value.
+ if (const auto nanMode = argmaxOp.getNanMode(); nanMode == "IGNORE") {
+ // Unordered comparison of NaN against itself will always return
+ // true.
+ Value isNaN = rewriter.create<arith::CmpFOp>(
+ argmaxOp.getLoc(), arith::CmpFPredicate::UNO, newValue,
+ newValue);
+ resultMax = rewriter.create<arith::SelectOp>(nestedLoc, isNaN,
+ oldValue, resultMax);
+ resultIndex = rewriter.create<arith::SelectOp>(
+ nestedLoc, isNaN, oldIndex, resultIndex);
+ }
nestedBuilder.create<linalg::YieldOp>(
nestedLoc, ValueRange({resultIndex, resultMax}));
});
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index a8fd536dd2548..55d07da1e3b55 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -724,11 +724,42 @@ class MaxPool2dConverter : public OpConversionPattern<tosa::MaxPool2dOp> {
rewriter.replaceOpWithNewOp<linalg::PoolingNhwcMaxUnsignedOp>(
op, ArrayRef<Type>{resultTy}, ValueRange{paddedInput, fakeWindowDims},
filledEmptyTensor, strideAttr, dilationAttr);
- } else {
- rewriter.replaceOpWithNewOp<linalg::PoolingNhwcMaxOp>(
- op, ArrayRef<Type>{resultTy}, ValueRange{paddedInput, fakeWindowDims},
- filledEmptyTensor, strideAttr, dilationAttr);
+ return llvm::success();
}
+
+ auto resultOp = rewriter.create<linalg::PoolingNhwcMaxOp>(
+ op->getLoc(), ArrayRef<Type>{resultTy},
+ ValueRange{paddedInput, fakeWindowDims}, filledEmptyTensor, strideAttr,
+ dilationAttr);
+
+ // "PROPAGATE" mode matches the behaviour of the LinAlg named op, so no
+ // compare and select materialization is required.
+ //
+ // In the case of "IGNORE" we need to insert a compare and select. Since
+ // we've already produced a named op we will just take its body and modify
+ // it to include the appropriate checks. If the current value is NaN the
+ // old value of pool will be taken otherwise we use the result.
+ if (const auto nanMode = op.getNanMode(); nanMode == "IGNORE") {
+ auto *block = resultOp.getBlock();
+ rewriter.setInsertionPointToEnd(block);
+
+ auto in = block->getArgument(0);
+ auto out = block->getArgument(2);
+
+ auto *oldYieldOp = &*block->rbegin();
+ auto result = oldYieldOp->getOperand(0);
+
+ Value isNaN = rewriter.create<arith::CmpFOp>(
+ op->getLoc(), arith::CmpFPredicate::UNO, in, in);
+
+ auto selectOp =
+ rewriter.create<arith::SelectOp>(op->getLoc(), isNaN, out, result);
+ auto newYieldOp = rewriter.create<linalg::YieldOp>(oldYieldOp->getLoc(),
+ selectOp.getResult());
+ rewriter.replaceOp(oldYieldOp, newYieldOp);
+ }
+
+ rewriter.replaceOp(op, resultOp);
return success();
}
};
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index a524359b49759..bb7b7be51191e 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -1,6 +1,7 @@
// RUN: mlir-opt --verify-each --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named))" %s -verify-diagnostics -o -| FileCheck %s
// RUN: mlir-opt --verify-each --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named{prefer-conv2d-kernel-layout-hwcf=true}))" %s -verify-diagnostics -o -| FileCheck --check-prefix="HWCF" %s
// RUN: mlir-opt --verify-each --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named,cse))" %s -verify-diagnostics -o -| FileCheck --check-prefix="CHECK-CSE" %s
+// RUN: mlir-opt --verify-each --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named,linalg-generalize-named-ops))" %s -verify-diagnostics -o -| FileCheck %s --check-prefix="CHECK-NAN"
// CHECK-LABEL: @matmul
func.func @matmul(%arg0: tensor<1x5x3xf32>, %arg1: tensor<1x3x6xf32>) -> (tensor<1x5x6xf32>) {
@@ -906,3 +907,24 @@ func.func @test_transpose_dyn_multiple_3d(%arg0: tensor<?x?x?xf32>) {
%1 = "tosa.transpose"(%arg0, %0) : (tensor<?x?x?xf32>, tensor<3xi32>) -> tensor<?x?x?xf32>
return
}
+
+// -----
+
+// CHECK-NAN-LABEL: @nan_propagation_modes
+func.func @nan_propagation_modes(%arg0: tensor<1x6x34x62xf32>) -> (tensor<1x4x32x62xf32>, tensor<1x4x32x62xf32>) {
+ // CHECK-NAN: linalg.generic
+ // CHECK-NAN-NOT: arith.maximumf
+ // CHECK-NAN-NOT: arith.cmpf uno
+ // CHECK-NAN-NOT: arith.select
+ // CHECK-NAN: linalg.yield
+ %0 = tosa.max_pool2d %arg0 {pad = array<i64: 0, 0, 0, 0>, kernel = array<i64: 3, 3>, stride = array<i64: 1, 1>, nan_mode = "PROPAGATE"} : (tensor<1x6x34x62xf32>) -> tensor<1x4x32x62xf32>
+
+ // CHECK-NAN: linalg.generic
+ // CHECK-NAN: arith.maximumf
+ // CHECK-NAN: arith.cmpf uno
+ // CHECK-NAN: arith.select
+ // CHECK-NAN: linalg.yield
+ %1 = tosa.max_pool2d %arg0 {pad = array<i64: 0, 0, 0, 0>, kernel = array<i64: 3, 3>, stride = array<i64: 1, 1>, nan_mode = "IGNORE"} : (tensor<1x6x34x62xf32>) -> tensor<1x4x32x62xf32>
+
+ return %0, %1 : tensor<1x4x32x62xf32>, tensor<1x4x32x62xf32>
+}
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 17add2d41afe7..86e6f9ed9264b 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1992,3 +1992,195 @@ func.func @test_cast_fp32_i64(%arg0: tensor<1xf32>) -> (tensor<1xi64>) {
%0 = tosa.cast %arg0 : (tensor<1xf32>) -> tensor<1xi64>
return %0: tensor<1xi64>
}
+
+// -----
+
+// CHECK-LABEL: @reduce_min_nan_propagate
+func.func @reduce_min_nan_propagate(%arg0: tensor<5x4xf32>, %arg1: tensor<5x4xf32>) -> () {
+ // CHECK: linalg.reduce
+ // CHECK: arith.minimumf
+ // CHECK-NOT: arith.cmpf uno
+ // CHECK-NOT: arith.select
+ // CHECK: linalg.yield
+ // CHECK-NOT: arith.constant 0x7FC00000
+ // CHECK-NOT: tensor.empty()
+ // CHECK-NOT: linalg.fill
+ // CHECK-NOT: tensor.empty()
+ // CHECK-NOT: select
+ // CHECK: return
+ %3 = tosa.reduce_min %arg0 {axis = 0 : i32, nan_mode = "PROPAGATE"} : (tensor<5x4xf32>) -> tensor<1x4xf32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @reduce_max_nan_propagate
+func.func @reduce_max_nan_propagate(%arg0: tensor<5x4xf32>, %arg1: tensor<5x4xf32>) -> () {
+ // CHECK: linalg.reduce
+ // CHECK: arith.maximumf
+ // CHECK-NOT: arith.cmpf uno
+ // CHECK-NOT: arith.select
+ // CHECK: linalg.yield
+ // CHECK-NOT: arith.constant 0x7FC00000
+ // CHECK-NOT: tensor.empty()
+ // CHECK-NOT: linalg.fill
+ // CHECK-NOT: tensor.empty()
+ // CHECK-NOT: select
+ // CHECK: return
+ %4 = tosa.reduce_max %arg0 {axis = 0 : i32, nan_mode = "PROPAGATE"} : (tensor<5x4xf32>) -> tensor<1x4xf32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @reduce_min_nan_ignore
+func.func @reduce_min_nan_ignore(%arg0: tensor<5x4xf32>, %arg1: tensor<5x4xf32>) -> () {
+ // CHECK: linalg.reduce
+ // CHECK: arith.minimumf
+ // CHECK: arith.cmpf uno
+ // CHECK: arith.select
+ // CHECK: linalg.yield
+ // CHECK: arith.constant 0x7FC00000
+ // CHECK: tensor.empty()
+ // CHECK: linalg.fill
+ // CHECK: tensor.empty()
+ // CHECK: select
+ %5 = tosa.reduce_min %arg0 {axis = 0 : i32, nan_mode = "IGNORE"} : (tensor<5x4xf32>) -> tensor<1x4xf32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @reduce_max_nan_ignore
+func.func @reduce_max_nan_ignore(%arg0: tensor<5x4xf32>, %arg1: tensor<5x4xf32>) -> () {
+ // CHECK: linalg.reduce
+ // CHECK: arith.maximumf
+ // CHECK: arith.cmpf uno
+ // CHECK: arith.select
+ // CHECK: linalg.yield
+ // CHECK: arith.constant 0x7FC00000
+ // CHECK: tensor.empty()
+ // CHECK: linalg.fill
+ // CHECK: tensor.empty()
+ // CHECK: select
+ %6 = tosa.reduce_max %arg0 {axis = 0 : i32, nan_mode = "IGNORE"} : (tensor<5x4xf32>) -> tensor<1x4xf32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @minimum_nan_propagate
+func.func @minimum_nan_propagate(%arg0: tensor<5x4xf32>, %arg1: tensor<5x4xf32>) -> () {
+ // CHECK: linalg.generic
+ // CHECK: arith.minimumf
+ // CHECK-NOT: arith.cmpf uno
+ // CHECK-NOT: arith.select
+ // CHECK: linalg.yield
+ %7 = tosa.minimum %arg0, %arg1 {nan_mode = "PROPAGATE"} : (tensor<5x4xf32>, tensor<5x4xf32>) -> tensor<5x4xf32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @maximum_nan_propagate
+func.func @maximum_nan_propagate(%arg0: tensor<5x4xf32>, %arg1: tensor<5x4xf32>) -> () {
+ // CHECK: linalg.generic
+ // CHECK: arith.maximumf
+ // CHECK-NOT: arith.cmpf uno
+ // CHECK-NOT: arith.select
+ // CHECK: linalg.yield
+ %8 = tosa.maximum %arg0, %arg1 {nan_mode = "PROPAGATE"} : (tensor<5x4xf32>, tensor<5x4xf32>) -> tensor<5x4xf32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @minimum_nan_ignore
+func.func @minimum_nan_ignore(%arg0: tensor<5x4xf32>, %arg1: tensor<5x4xf32>) -> () {
+ // CHECK: linalg.generic
+ // CHECK: arith.minimumf
+ // CHECK: arith.cmpf uno
+ // CHECK: arith.cmpf uno
+ // CHECK: arith.select
+ // CHECK: arith.select
+ // CHECK: linalg.yield
+ %9 = tosa.minimum %arg0, %arg1 {nan_mode = "IGNORE"} : (tensor<5x4xf32>, tensor<5x4xf32>) -> tensor<5x4xf32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @maximum_nan_ignore
+func.func @maximum_nan_ignore(%arg0: tensor<5x4xf32>, %arg1: tensor<5x4xf32>) -> () {
+ // CHECK: linalg.generic
+ // CHECK: arith.maximumf
+ // CHECK: arith.cmpf uno
+ // CHECK: arith.cmpf uno
+ // CHECK: arith.select
+ // CHECK: arith.select
+ // CHECK: linalg.yield
+ %10 = tosa.maximum %arg0, %arg1 {nan_mode = "IGNORE"} : (tensor<5x4xf32>, tensor<5x4xf32>) -> tensor<5x4xf32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @argmax_nan_propagate
+func.func @argmax_nan_propagate(%arg0: tensor<5x4xf32>, %arg1: tensor<5x4xf32>) -> () {
+ // CHECK: linalg.generic
+ // CHECK: arith.cmpf ogt
+ // CHECK: arith.select
+ // CHECK: arith.select
+ // CHECK-NOT: arith.cmpf uno
+ // CHECK-NOT: arith.cmpf uno
+ // CHECK-NOT: arith.select
+ // CHECK-NOT: arith.select
+ // CHECK: linalg.yield
+ %11 = tosa.argmax %arg0 {axis = 0 : i32, nan_mode = "PROPAGATE"} : (tensor<5x4xf32>) -> tensor<4xi32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @argmax_nan_ignore
+func.func @argmax_nan_ignore(%arg0: tensor<5x4xf32>, %arg1: tensor<5x4xf32>) -> () {
+ // CHECK: linalg.generic
+ // CHECK: arith.cmpf ogt
+ // CHECK: arith.select
+ // CHECK: arith.select
+ // CHECK: arith.cmpf uno
+ // CHECK: arith.select
+ // CHECK: arith.select
+ // CHECK: linalg.yield
+ %12 = tosa.argmax %arg0 {axis = 0 : i32, nan_mode = "IGNORE"} : (tensor<5x4xf32>) -> tensor<4xi32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @clamp_nan_propagate
+func.func @clamp_nan_propagate(%arg0: tensor<5x4xf32>, %arg1: tensor<5x4xf32>) -> () {
+ // CHECK: linalg.generic
+ // CHECK: arith.minimumf
+ // CHECK: arith.maximumf
+ // CHECK-NOT: arith.cmpf uno
+ // CHECK-NOT: arith.select
+ // CHECK: linalg.yield
+ %13 = tosa.clamp %arg0 {min_val = 1.0 : f32, max_val = 5.0 : f32, nan_mode = "PROPAGATE"} : (tensor<5x4xf32>) -> tensor<5x4xf32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @clamp_nan_ignore
+func.func @clamp_nan_ignore(%arg0: tensor<5x4xf32>, %arg1: tensor<5x4xf32>) -> () {
+ // CHECK: linalg.generic
+ // CHECK: arith.minimumf
+ // CHECK: arith.maximumf
+ // CHECK: arith.cmpf uno
+ // CHECK: arith.select
+ // CHECK: linalg.yield
+ %14 = tosa.clamp %arg0 {min_val = 1.0 : f32, max_val = 5.0 : f32, nan_mode = "IGNORE"} : (tensor<5x4xf32>) -> tensor<5x4xf32>
+
+ return
+}
More information about the Mlir-commits
mailing list