[Mlir-commits] [mlir] cf3b036 - [mlir][tosa][tosa-to-linalg] Add NaN Mode Lowering (#125668)

llvmlistbot at llvm.org
Tue Feb 25 08:22:14 PST 2025


Author: Jack Frankland
Date: 2025-02-25T16:22:10Z
New Revision: cf3b0368a55c1c285dd80f12b044b58e87a425ac

URL: https://github.com/llvm/llvm-project/commit/cf3b0368a55c1c285dd80f12b044b58e87a425ac
DIFF: https://github.com/llvm/llvm-project/commit/cf3b0368a55c1c285dd80f12b044b58e87a425ac.diff

LOG: [mlir][tosa][tosa-to-linalg] Add NaN Mode Lowering (#125668)

Add support for NaN propagation lowering in the `tosa-to-linalg` and
`tosa-to-linalg-named` conversions by conditionally checking for NaN in
the case of "IGNORE" semantics and materializing the appropriate select
operations. Note that the default behaviour of "PROPAGATE" matches that
of the arith dialect, so in that case we can avoid creating the checks
altogether.
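
For example, with "IGNORE" semantics a binary operator such as
`tosa.maximum` lowers to roughly the following sequence inside the
`linalg.generic` body (a sketch; value names are illustrative):

  %max     = arith.maximumf %lhs, %rhs : f32
  %lhs_nan = arith.cmpf uno, %lhs, %lhs : f32
  %rhs_nan = arith.cmpf uno, %rhs, %rhs : f32
  %tmp     = arith.select %lhs_nan, %rhs, %max : f32
  %result  = arith.select %rhs_nan, %lhs, %tmp : f32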

Add appropriate lit tests, including negative tests which check that the
various comparisons and selects are only materialized when required.

This affects the following TOSA operators:
* arg_max
* max_pool_2d
* clamp
* reduce_max
* reduce_min
* maximum
* minimum

Signed-off-by: Jack Frankland <jack.frankland at arm.com>

Added: 
    

Modified: 
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
    mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
    mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 7b70b3ab8afc9..607667fcc6945 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -32,10 +32,47 @@
 #include "llvm/ADT/Sequence.h"
 
 #include <numeric>
+#include <type_traits>
 
 using namespace mlir;
 using namespace mlir::tosa;
 
+// Helper function to materialize the semantically correct compare and select
+// operations given a binary operation with a specific NaN propagation mode.
+//
+// In the case of "PROPAGATE" semantics no compare and select is required and
+// this function returns the result of the binary operation unchanged.
+//
+// In the case of "IGNORE" semantics this function materializes comparisons
+// checking each operand for NaN (an unordered comparison returns true for
+// any NaN argument) and then selects between the other, non-NaN operand and
+// the calculated result depending on whether the lhs or rhs is NaN. In
+// pseudo code:
+//
+// binary<op>(lhs, rhs):
+//   result = op(lhs, rhs)
+//   if lhs == NaN return rhs
+//   if rhs == NaN return lhs
+//   return result
+template <typename OpTy>
+static Value
+materializeBinaryNanCheckIfRequired(OpTy op, PatternRewriter &rewriter,
+                                    Value lhs, Value rhs, Value result) {
+  auto nanMode = op.getNanMode();
+  if (nanMode == "PROPAGATE")
+    return result;
+
+  // Unordered comparison of NaN against itself will always return true.
+  Value lhsIsNaN = rewriter.create<arith::CmpFOp>(
+      op.getLoc(), arith::CmpFPredicate::UNO, lhs, lhs);
+  Value rhsIsNaN = rewriter.create<arith::CmpFOp>(
+      op.getLoc(), arith::CmpFPredicate::UNO, rhs, rhs);
+  Value rhsOrResult =
+      rewriter.create<arith::SelectOp>(op.getLoc(), lhsIsNaN, rhs, result);
+  return rewriter.create<arith::SelectOp>(op.getLoc(), rhsIsNaN, lhs,
+                                          rhsOrResult);
+}
+
 template <typename T>
 static arith::ConstantOp
 createConstFromIntAttribute(Operation *op, const std::string &attrName,
@@ -367,7 +404,9 @@ static Value createLinalgBodyCalculationForElementwiseOp(
 
   // tosa::MaximumOp
   if (isa<tosa::MaximumOp>(op) && isa<FloatType>(elementTy)) {
-    return rewriter.create<arith::MaximumFOp>(loc, args[0], args[1]);
+    auto max = rewriter.create<arith::MaximumFOp>(loc, args[0], args[1]);
+    return materializeBinaryNanCheckIfRequired(llvm::cast<tosa::MaximumOp>(op),
+                                               rewriter, args[0], args[1], max);
   }
 
   if (isa<tosa::MaximumOp>(op) && elementTy.isSignlessInteger()) {
@@ -376,7 +415,9 @@ static Value createLinalgBodyCalculationForElementwiseOp(
 
   // tosa::MinimumOp
   if (isa<tosa::MinimumOp>(op) && isa<FloatType>(elementTy)) {
-    return rewriter.create<arith::MinimumFOp>(loc, args[0], args[1]);
+    auto min = rewriter.create<arith::MinimumFOp>(loc, args[0], args[1]);
+    return materializeBinaryNanCheckIfRequired(llvm::cast<tosa::MinimumOp>(op),
+                                               rewriter, args[0], args[1], min);
   }
 
   if (isa<tosa::MinimumOp>(op) && elementTy.isSignlessInteger()) {
@@ -404,7 +445,31 @@ static Value createLinalgBodyCalculationForElementwiseOp(
         loc, elementTy, rewriter.getFloatAttr(elementTy, minApf));
     auto max = rewriter.create<arith::ConstantOp>(
         loc, elementTy, rewriter.getFloatAttr(elementTy, maxApf));
-    return clampFloatHelper(loc, args[0], min, max, rewriter);
+    auto result = clampFloatHelper(loc, args[0], min, max, rewriter);
+
+    auto clampOp = llvm::cast<tosa::ClampOp>(op);
+    const auto nanMode = clampOp.getNanMode();
+    // In the case of "PROPAGATE" semantics no compare and selection is
+    // required.
+    if (nanMode == "PROPAGATE")
+      return result;
+
+    // In the case of "IGNORE" semantics materialize a comparison
+    // of the current operand to the reduction which will return true for a NaN
+    // argument and then selects between the initial reduction value and the
+    // calculated result based on whether the argument is NaN or not. In pseudo
+    // code:
+    //
+    // reduce<op>(x, init):
+    //   result = op(init, x)
+    //   return init if x == NaN else result
+
+    // Unordered comparison of NaN against itself will always return true.
+    Value isNaN = rewriter.create<arith::CmpFOp>(
+        op->getLoc(), arith::CmpFPredicate::UNO, args[0], args[0]);
+    // TOSA specifies that in "IGNORE" NaN mode the result is "min" if the
+    // input is NaN.
+    return rewriter.create<arith::SelectOp>(op->getLoc(), isNaN, min, result);
   }
 
   if (isa<tosa::ClampOp>(op) && isa<IntegerType>(elementTy)) {
@@ -1078,7 +1143,8 @@ static Value createLinalgBodyCalculationForReduceOp(Operation *op,
 // Performs the match and rewrite for reduction operations. This includes
 // declaring a correctly sized initial value, and the linalg.generic operation
 // that reduces across the specified axis.
-static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
+template <typename OpTy>
+static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
                                                  PatternRewriter &rewriter) {
   auto loc = op->getLoc();
   auto inputTy = cast<ShapedType>(op->getOperand(0).getType());
@@ -1096,6 +1162,9 @@ static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
     }
   }
 
+  SmallVector<Value> inputs, outputs;
+  inputs.push_back(input);
+
   // First fill the output buffer with the init value.
   auto emptyTensor =
       rewriter
@@ -1113,26 +1182,127 @@ static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
                           .create<linalg::FillOp>(loc, ValueRange{fillValue},
                                                   ValueRange{emptyTensor})
                           .result();
+  outputs.push_back(filledTensor);
+
+  bool isNanIgnoreMode = false;
+  if constexpr (std::is_same_v<OpTy, tosa::ReduceMinOp> ||
+                std::is_same_v<OpTy, tosa::ReduceMaxOp>) {
+    if (op.getNanMode() == "IGNORE") {
+      isNanIgnoreMode = true;
+      // Because the TOSA spec requires the result be NaN iff all elements in
+      // the reduction are NaN we can't simply perform a compare and select.
+      // Additionally we have to keep track of whether we've seen any non-NaN
+      // values and then do a final select based on this predicate.
+      auto trueAttr = rewriter.getBoolAttr(true);
+      auto trueValue = rewriter.create<arith::ConstantOp>(loc, trueAttr);
+      auto emptyBoolTensor =
+          rewriter
+              .create<tensor::EmptyOp>(loc, reduceShape, trueValue.getType(),
+                                       dynDims)
+              .getResult();
+      auto allResultsNaNTensor =
+          rewriter
+              .create<linalg::FillOp>(loc, ValueRange{trueValue},
+                                      ValueRange{emptyBoolTensor})
+              .result();
+      // Note that because the linalg::ReduceOp has two variadic arguments
+      // (inputs and outputs) and it has the SameVariadicOperandSize trait we
+      // need to have the same number of inputs and outputs.
+      //
+      // The second input isn't actually used anywhere since the value used to
+      // update the NaN flag is calculated inside the body of the reduction and
+      // then used to update an out value.
+      // In order to satisfy type constraints we just pass another copy of the
+      // input here.
+      inputs.push_back(input);
+      outputs.push_back(allResultsNaNTensor);
+    }
+  }
 
   bool didEncounterError = false;
-  auto linalgOp = rewriter.create<linalg::ReduceOp>(
-      loc, input, filledTensor, axis,
+  linalg::LinalgOp linalgOp = rewriter.create<linalg::ReduceOp>(
+      loc, inputs, outputs, axis,
       [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) {
+        std::array<Value, 2> binaryArgs{
+            blockArgs[0], isNanIgnoreMode ? blockArgs[2] : blockArgs[1]};
         auto result = createLinalgBodyCalculationForReduceOp(
-            op, blockArgs, elementTy, rewriter);
+            op, binaryArgs, elementTy, rewriter);
         if (result)
           didEncounterError = true;
 
-        nestedBuilder.create<linalg::YieldOp>(loc, result);
+        SmallVector<Value> resultsToYield;
+        if (isNanIgnoreMode) {
+          auto inputValue = blockArgs[0];
+          auto initialValue = blockArgs[2];
+          auto oldAllResultsNanFlagValue = blockArgs[3];
+
+          // Unordered comparison of NaN against itself will always return true.
+          Value isNaN = nestedBuilder.create<arith::CmpFOp>(
+              op->getLoc(), arith::CmpFPredicate::UNO, inputValue, inputValue);
+          // If we've encountered a NaN, take the non-NaN value.
+          auto selectOp = nestedBuilder.create<arith::SelectOp>(
+              op->getLoc(), isNaN, initialValue, result);
+          // Update the flag which tracks whether every value seen so far has
+          // been NaN; it becomes false once a non-NaN value is encountered.
+          auto newAllResultsNanFlagValue = nestedBuilder.create<arith::AndIOp>(
+              op->getLoc(), oldAllResultsNanFlagValue, isNaN);
+          resultsToYield.push_back(selectOp);
+          resultsToYield.push_back(newAllResultsNanFlagValue);
+        } else {
+          resultsToYield.push_back(result);
+        }
+        nestedBuilder.create<linalg::YieldOp>(loc, resultsToYield);
       });
 
   if (!didEncounterError)
     return rewriter.notifyMatchFailure(
         op, "unable to create linalg.generic body for reduce op");
 
+  if (isNanIgnoreMode) {
+    // Materialize a check to see whether we encountered any non-NaN values; if
+    // we didn't, we need to select a tensor of NaNs since the result would just
+    // be the initial identity value propagated through all the compares and
+    // selects inside the reduction.
+
+    // Create a tensor full of NaNs.
+    auto nanValueAttr = rewriter.getFloatAttr(
+        elementTy,
+        APFloat::getNaN(cast<FloatType>(elementTy).getFloatSemantics(), false));
+    auto nanValue = rewriter.create<arith::ConstantOp>(loc, nanValueAttr);
+    auto emptyNanTensor =
+        rewriter
+            .create<tensor::EmptyOp>(loc, reduceShape,
+                                     resultTy.getElementType(), dynDims)
+            .getResult();
+    auto nanFilledTensor =
+        rewriter
+            .create<linalg::FillOp>(loc, ValueRange{nanValue},
+                                    ValueRange{emptyNanTensor})
+            .result();
+
+    // Create an empty tensor; no need to fill it since it will be
+    // overwritten by the select.
+    auto finalEmptyTensor =
+        rewriter
+            .create<tensor::EmptyOp>(loc, reduceShape,
+                                     resultTy.getElementType(), dynDims)
+            .getResult();
+
+    // Do a selection between the tensors akin to:
+    // result = NaN if "all results NaN" else result.
+    SmallVector<Value> ins, outs;
+    ins.push_back(linalgOp->getOpResult(1));
+    ins.push_back(nanFilledTensor);
+    ins.push_back(linalgOp->getResult(0));
+    outs.push_back(finalEmptyTensor);
+    auto linalgSelect =
+        rewriter.create<linalg::SelectOp>(op->getLoc(), ins, outs);
+    linalgOp = linalgSelect;
+  }
+
   SmallVector<ReassociationExprs, 4> reassociationMap;
   uint64_t expandInputRank =
-      cast<ShapedType>(linalgOp.getResults()[0].getType()).getRank();
+      cast<ShapedType>(linalgOp->getResults()[0].getType()).getRank();
   reassociationMap.resize(expandInputRank);
 
   for (uint64_t i = 0; i < expandInputRank; i++) {
@@ -1151,7 +1321,7 @@ static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
   // not have access to such information. This matters when handling dynamically
   // sized tensors.
   rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
-      op, resultTy, linalgOp.getResults()[0], reassociationMap);
+      op, resultTy, linalgOp->getResults()[0], reassociationMap);
   return success();
 }
 
@@ -2097,6 +2267,27 @@ class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
               nestedLoc, predicate, newValue, oldValue);
           auto resultIndex = rewriter.create<arith::SelectOp>(
               nestedLoc, predicate, newIndex, oldIndex);
+
+          // Check if we need to materialize compare and select for the given
+          // NaN propagation mode.
+
+          // "PROPAGATE" matches the default NaN propagation mode of the arith
+          // dialect so no compare and select is required.
+          //
+          // In the case of "IGNORE" we check if the current argument is NaN;
+          // if so we select the old index and value, otherwise we take the
+          // updated index and value.
+          if (const auto nanMode = argmaxOp.getNanMode(); nanMode == "IGNORE") {
+            // Unordered comparison of NaN against itself will always return
+            // true.
+            Value isNaN = rewriter.create<arith::CmpFOp>(
+                argmaxOp.getLoc(), arith::CmpFPredicate::UNO, newValue,
+                newValue);
+            resultMax = rewriter.create<arith::SelectOp>(nestedLoc, isNaN,
+                                                         oldValue, resultMax);
+            resultIndex = rewriter.create<arith::SelectOp>(
+                nestedLoc, isNaN, oldIndex, resultIndex);
+          }
           nestedBuilder.create<linalg::YieldOp>(
               nestedLoc, ValueRange({resultIndex, resultMax}));
         });
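
With "IGNORE" semantics, reduce_max/reduce_min lower to roughly the
following structure, where a second reduction result tracks whether every
reduced element was NaN (a sketch; shapes and value names are
illustrative):

  %red:2 = linalg.reduce ins(%in, %in : ...) outs(%init, %all_nan : ...)
    dimensions = [0]
    (%x: f32, %unused: f32, %acc: f32, %flag: i1) {
      %max     = arith.maximumf %x, %acc : f32
      %is_nan  = arith.cmpf uno, %x, %x : f32
      %val     = arith.select %is_nan, %acc, %max : f32
      %new_flg = arith.andi %flag, %is_nan : i1
      linalg.yield %val, %new_flg : f32, i1
    }
  %result = linalg.select ins(%red#1, %nan_tensor, %red#0 : ...)
                          outs(%empty : ...) -> ...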

diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index a8fd536dd2548..620b5f95825f6 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -724,11 +724,44 @@ class MaxPool2dConverter : public OpConversionPattern<tosa::MaxPool2dOp> {
       rewriter.replaceOpWithNewOp<linalg::PoolingNhwcMaxUnsignedOp>(
           op, ArrayRef<Type>{resultTy}, ValueRange{paddedInput, fakeWindowDims},
           filledEmptyTensor, strideAttr, dilationAttr);
-    } else {
-      rewriter.replaceOpWithNewOp<linalg::PoolingNhwcMaxOp>(
-          op, ArrayRef<Type>{resultTy}, ValueRange{paddedInput, fakeWindowDims},
-          filledEmptyTensor, strideAttr, dilationAttr);
+      return llvm::success();
     }
+
+    auto resultOp = rewriter.create<linalg::PoolingNhwcMaxOp>(
+        op->getLoc(), ArrayRef<Type>{resultTy},
+        ValueRange{paddedInput, fakeWindowDims}, filledEmptyTensor, strideAttr,
+        dilationAttr);
+
+    rewriter.replaceOp(op, resultOp);
+    // "PROPAGATE" mode matches the behaviour of the LinAlg named op, so no
+    // compare and select materialization is required.
+    //
+    // In the case of "IGNORE" we need to insert a compare and select. Since
+    // we've already produced a named op we will just take its body and modify
+    // it to include the appropriate checks. If the current value is NaN the
+    // old value of pool will be taken otherwise we use the result.
+    if (const auto nanMode = op.getNanMode(); nanMode == "IGNORE") {
+      auto genericOp = rewriter.create<linalg::GenericOp>(
+          op->getLoc(), resultOp.getType(0), resultOp.getInputs(),
+          resultOp.getOutputs(), resultOp.getIndexingMapsArray(),
+          resultOp.getIteratorTypesArray(),
+          [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
+            IRMapping map;
+            auto oldBlock = resultOp.getRegion().begin();
+            auto oldArgs = oldBlock->getArguments();
+            auto &oldMaxOp = *resultOp.getBlock()->begin();
+            map.map(oldArgs, blockArgs);
+            auto *newOp = opBuilder.clone(oldMaxOp, map);
+            Value isNaN = opBuilder.create<arith::CmpFOp>(
+                op->getLoc(), arith::CmpFPredicate::UNO, blockArgs.front(),
+                blockArgs.front());
+            auto selectOp = opBuilder.create<arith::SelectOp>(
+                op->getLoc(), isNaN, blockArgs.back(), newOp->getResult(0));
+            opBuilder.create<linalg::YieldOp>(loc, selectOp.getResult());
+          });
+      rewriter.replaceOp(resultOp, genericOp);
+    }
+
     return success();
   }
 };
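
With "IGNORE" semantics the pooling result above is rewritten into a
linalg.generic whose body is roughly (a sketch; value names are
illustrative):

  ^bb0(%in: f32, %acc: f32):
    %max    = arith.maximumf %in, %acc : f32
    %is_nan = arith.cmpf uno, %in, %in : f32
    %res    = arith.select %is_nan, %acc, %max : f32
    linalg.yield %res : f32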

diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index a524359b49759..980805ad94b7a 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -906,3 +906,27 @@ func.func @test_transpose_dyn_multiple_3d(%arg0: tensor<?x?x?xf32>) {
   %1 = "tosa.transpose"(%arg0, %0) : (tensor<?x?x?xf32>, tensor<3xi32>) -> tensor<?x?x?xf32>
   return
 }
+
+// -----
+
+// CHECK-LABEL: @max_pool2d_nan_propagate
+func.func @max_pool2d_nan_propagate(%arg0: tensor<1x6x34x62xf32>) -> (tensor<1x4x32x62xf32>) {
+  // CHECK: linalg.pooling_nhwc_max
+  // CHECK-NOT: linalg.generic
+  %0 = tosa.max_pool2d %arg0 {pad = array<i64: 0, 0, 0, 0>, kernel = array<i64: 3, 3>, stride = array<i64: 1, 1>, nan_mode = "PROPAGATE"} : (tensor<1x6x34x62xf32>) -> tensor<1x4x32x62xf32>
+  return %0 : tensor<1x4x32x62xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @max_pool2d_nan_ignore
+func.func @max_pool2d_nan_ignore(%arg0: tensor<1x6x34x62xf32>) -> (tensor<1x4x32x62xf32>) {
+  // CHECK-NOT: linalg.pooling_nhwc_max
+  // CHECK: linalg.generic
+  // CHECK: arith.maximumf
+  // CHECK: arith.cmpf uno
+  // CHECK: arith.select
+  // CHECK: linalg.yield
+  %0 = tosa.max_pool2d %arg0 {pad = array<i64: 0, 0, 0, 0>, kernel = array<i64: 3, 3>, stride = array<i64: 1, 1>, nan_mode = "IGNORE"} : (tensor<1x6x34x62xf32>) -> tensor<1x4x32x62xf32>
+  return %0: tensor<1x4x32x62xf32>
+}

diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 17add2d41afe7..86e6f9ed9264b 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1992,3 +1992,195 @@ func.func @test_cast_fp32_i64(%arg0: tensor<1xf32>) -> (tensor<1xi64>) {
   %0 = tosa.cast %arg0 : (tensor<1xf32>) -> tensor<1xi64>
   return %0: tensor<1xi64>
 }
+
+// -----
+
+// CHECK-LABEL: @reduce_min_nan_propagate
+func.func @reduce_min_nan_propagate(%arg0: tensor<5x4xf32>, %arg1: tensor<5x4xf32>) -> () {
+  // CHECK: linalg.reduce
+  // CHECK: arith.minimumf
+  // CHECK-NOT: arith.cmpf uno
+  // CHECK-NOT: arith.select
+  // CHECK: linalg.yield
+  // CHECK-NOT: arith.constant 0x7FC00000
+  // CHECK-NOT: tensor.empty()
+  // CHECK-NOT: linalg.fill
+  // CHECK-NOT: tensor.empty()
+  // CHECK-NOT: select
+  // CHECK: return
+  %3 = tosa.reduce_min %arg0 {axis = 0 : i32, nan_mode = "PROPAGATE"} : (tensor<5x4xf32>) -> tensor<1x4xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @reduce_max_nan_propagate
+func.func @reduce_max_nan_propagate(%arg0: tensor<5x4xf32>, %arg1: tensor<5x4xf32>) -> () {
+  // CHECK: linalg.reduce
+  // CHECK: arith.maximumf
+  // CHECK-NOT: arith.cmpf uno
+  // CHECK-NOT: arith.select
+  // CHECK: linalg.yield
+  // CHECK-NOT: arith.constant 0x7FC00000
+  // CHECK-NOT: tensor.empty()
+  // CHECK-NOT: linalg.fill
+  // CHECK-NOT: tensor.empty()
+  // CHECK-NOT: select
+  // CHECK: return
+  %4 = tosa.reduce_max %arg0 {axis = 0 : i32, nan_mode = "PROPAGATE"} : (tensor<5x4xf32>) -> tensor<1x4xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @reduce_min_nan_ignore
+func.func @reduce_min_nan_ignore(%arg0: tensor<5x4xf32>, %arg1: tensor<5x4xf32>) -> () {
+  // CHECK: linalg.reduce
+  // CHECK: arith.minimumf
+  // CHECK: arith.cmpf uno
+  // CHECK: arith.select
+  // CHECK: linalg.yield
+  // CHECK: arith.constant 0x7FC00000
+  // CHECK: tensor.empty()
+  // CHECK: linalg.fill
+  // CHECK: tensor.empty()
+  // CHECK: select
+  %5 = tosa.reduce_min %arg0 {axis = 0 : i32, nan_mode = "IGNORE"} : (tensor<5x4xf32>) -> tensor<1x4xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @reduce_max_nan_ignore
+func.func @reduce_max_nan_ignore(%arg0: tensor<5x4xf32>, %arg1: tensor<5x4xf32>) -> () {
+  // CHECK: linalg.reduce
+  // CHECK: arith.maximumf
+  // CHECK: arith.cmpf uno
+  // CHECK: arith.select
+  // CHECK: linalg.yield
+  // CHECK: arith.constant 0x7FC00000
+  // CHECK: tensor.empty()
+  // CHECK: linalg.fill
+  // CHECK: tensor.empty()
+  // CHECK: select
+  %6 = tosa.reduce_max %arg0 {axis = 0 : i32, nan_mode = "IGNORE"} : (tensor<5x4xf32>) -> tensor<1x4xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @minimum_nan_propagate
+func.func @minimum_nan_propagate(%arg0: tensor<5x4xf32>, %arg1: tensor<5x4xf32>) -> () {
+  // CHECK: linalg.generic
+  // CHECK: arith.minimumf
+  // CHECK-NOT: arith.cmpf uno
+  // CHECK-NOT: arith.select
+  // CHECK: linalg.yield
+  %7 = tosa.minimum %arg0, %arg1 {nan_mode = "PROPAGATE"} : (tensor<5x4xf32>, tensor<5x4xf32>) -> tensor<5x4xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @maximum_nan_propagate
+func.func @maximum_nan_propagate(%arg0: tensor<5x4xf32>, %arg1: tensor<5x4xf32>) -> () {
+  // CHECK: linalg.generic
+  // CHECK: arith.maximumf
+  // CHECK-NOT: arith.cmpf uno
+  // CHECK-NOT: arith.select
+  // CHECK: linalg.yield
+  %8 = tosa.maximum %arg0, %arg1 {nan_mode = "PROPAGATE"} : (tensor<5x4xf32>, tensor<5x4xf32>) -> tensor<5x4xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @minimum_nan_ignore
+func.func @minimum_nan_ignore(%arg0: tensor<5x4xf32>, %arg1: tensor<5x4xf32>) -> () {
+  // CHECK: linalg.generic
+  // CHECK: arith.minimumf
+  // CHECK: arith.cmpf uno
+  // CHECK: arith.cmpf uno
+  // CHECK: arith.select
+  // CHECK: arith.select
+  // CHECK: linalg.yield
+  %9 = tosa.minimum %arg0, %arg1 {nan_mode = "IGNORE"} : (tensor<5x4xf32>, tensor<5x4xf32>) -> tensor<5x4xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @maximum_nan_ignore
+func.func @maximum_nan_ignore(%arg0: tensor<5x4xf32>, %arg1: tensor<5x4xf32>) -> () {
+  // CHECK: linalg.generic
+  // CHECK: arith.maximumf
+  // CHECK: arith.cmpf uno
+  // CHECK: arith.cmpf uno
+  // CHECK: arith.select
+  // CHECK: arith.select
+  // CHECK: linalg.yield
+  %10 = tosa.maximum %arg0, %arg1 {nan_mode = "IGNORE"} : (tensor<5x4xf32>, tensor<5x4xf32>) -> tensor<5x4xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @argmax_nan_propagate
+func.func @argmax_nan_propagate(%arg0: tensor<5x4xf32>, %arg1: tensor<5x4xf32>) -> () {
+  // CHECK: linalg.generic
+  // CHECK: arith.cmpf ogt
+  // CHECK: arith.select
+  // CHECK: arith.select
+  // CHECK-NOT: arith.cmpf uno
+  // CHECK-NOT: arith.cmpf uno
+  // CHECK-NOT: arith.select
+  // CHECK-NOT: arith.select
+  // CHECK: linalg.yield
+  %11 = tosa.argmax %arg0 {axis = 0 : i32, nan_mode = "PROPAGATE"} : (tensor<5x4xf32>)  -> tensor<4xi32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @argmax_nan_ignore
+func.func @argmax_nan_ignore(%arg0: tensor<5x4xf32>, %arg1: tensor<5x4xf32>) -> () {
+  // CHECK: linalg.generic
+  // CHECK: arith.cmpf ogt
+  // CHECK: arith.select
+  // CHECK: arith.select
+  // CHECK: arith.cmpf uno
+  // CHECK: arith.select
+  // CHECK: arith.select
+  // CHECK: linalg.yield
+  %12 = tosa.argmax %arg0 {axis = 0 : i32, nan_mode = "IGNORE"} : (tensor<5x4xf32>)  -> tensor<4xi32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @clamp_nan_propagate
+func.func @clamp_nan_propagate(%arg0: tensor<5x4xf32>, %arg1: tensor<5x4xf32>) -> () {
+  // CHECK: linalg.generic
+  // CHECK: arith.minimumf
+  // CHECK: arith.maximumf
+  // CHECK-NOT: arith.cmpf uno
+  // CHECK-NOT: arith.select
+  // CHECK: linalg.yield
+  %13 = tosa.clamp %arg0 {min_val =  1.0 : f32, max_val = 5.0 : f32, nan_mode = "PROPAGATE"} : (tensor<5x4xf32>) -> tensor<5x4xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @clamp_nan_ignore
+func.func @clamp_nan_ignore(%arg0: tensor<5x4xf32>, %arg1: tensor<5x4xf32>) -> () {
+  // CHECK: linalg.generic
+  // CHECK: arith.minimumf
+  // CHECK: arith.maximumf
+  // CHECK: arith.cmpf uno
+  // CHECK: arith.select
+  // CHECK: linalg.yield
+  %14 = tosa.clamp %arg0 {min_val = 1.0 : f32, max_val = 5.0 : f32, nan_mode = "IGNORE"} : (tensor<5x4xf32>) -> tensor<5x4xf32>
+
+  return
+}
