[Mlir-commits] [mlir] [mlir][tosa][tosa-to-linalg] Add NaN Mode Lowering (PR #125668)

Jack Frankland llvmlistbot at llvm.org
Wed Feb 12 07:23:46 PST 2025


https://github.com/FranklandJack updated https://github.com/llvm/llvm-project/pull/125668

>From 63d250259adb0ed9846f8c79ef4a473f9446588d Mon Sep 17 00:00:00 2001
From: Jack Frankland <jack.frankland at arm.com>
Date: Mon, 3 Feb 2025 18:36:09 +0000
Subject: [PATCH] [mlir][tosa][tosa-to-linalg] Add NaN Mode Lowering

Add support for NaN propagation lowering in the `tosa-to-linalg` and
`tosa-to-linalg-named` conversions by conditionally checking for NaN in
the case of ignore semantics and materializing the appropriate select
operations. Note that the default behviour of "propagate" matches
that of the arith dialect and so in that case we can avoid creating the
checks altogether.
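
For illustration, the "IGNORE" lowering of a binary op such as
tosa.maximum materializes IR along the following lines inside the
linalg.generic body (a simplified sketch; SSA names are illustrative
and not taken from the lit tests):

  ^bb0(%lhs: f32, %rhs: f32, %out: f32):
    %max     = arith.maximumf %lhs, %rhs : f32
    %lhs_nan = arith.cmpf uno, %lhs, %lhs : f32
    %rhs_nan = arith.cmpf uno, %rhs, %rhs : f32
    %sel0    = arith.select %lhs_nan, %rhs, %max : f32
    %sel1    = arith.select %rhs_nan, %lhs, %sel0 : f32
    linalg.yield %sel1 : f32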

Add appropriate lit tests, including negative tests which check that the
various compares and selects are materialized only where required.

This affects the following TOSA operators:
* arg_max
* max_pool_2d
* clamp
* reduce_max
* reduce_min
* maximum
* minimum
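
The reductions need extra care: TOSA requires the result to be NaN if
and only if every element along the reduced axis is NaN, so the
"IGNORE" lowering also threads an "all NaN" flag through the
linalg.reduce and materializes a final linalg.select between a
NaN-filled tensor and the reduction result. A simplified sketch for
reduce_max (shapes, types and names illustrative):

  %red:2 = linalg.reduce ins(%in, %in : ...) outs(%init, %true : ...)
           dimensions = [0]
    (%x: f32, %unused: f32, %acc: f32, %all_nan: i1) {
      %r       = arith.maximumf %x, %acc : f32
      %is_nan  = arith.cmpf uno, %x, %x : f32
      %sel     = arith.select %is_nan, %acc, %r : f32
      %new_all = arith.andi %all_nan, %is_nan : i1
      linalg.yield %sel, %new_all : f32, i1
    }
  %out = linalg.select ins(%red#1, %nans, %red#0 : ...) outs(%empty : ...)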

Signed-off-by: Jack Frankland <jack.frankland at arm.com>
---
 mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h   |  20 ++
 .../Conversion/TosaToLinalg/TosaToLinalg.cpp  | 217 +++++++++++++++++-
 .../TosaToLinalg/TosaToLinalgNamed.cpp        |  44 +++-
 .../TosaToLinalg/tosa-to-linalg-named.mlir    |  22 ++
 .../TosaToLinalg/tosa-to-linalg.mlir          | 124 ++++++++++
 5 files changed, 414 insertions(+), 13 deletions(-)
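
As a note on the tosa-to-linalg-named change: after
linalg-generalize-named-ops, the max_pool2d "IGNORE" lowering ends up
with a pooling body along these lines (a simplified sketch with
illustrative names, matching the new CHECK-NAN test):

  ^bb0(%in: f32, %kernel: f32, %acc: f32):
    %max    = arith.maximumf %acc, %in : f32
    %is_nan = arith.cmpf uno, %in, %in : f32
    %sel    = arith.select %is_nan, %acc, %max : f32
    linalg.yield %sel : f32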

diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
index 069073bc2d164..ff5db80a82633 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
@@ -267,6 +267,26 @@ extractConvZpPair(TosaConvOp op, PatternRewriter &rewriter) {
 
   return std::make_optional<ConvZpPair>(*maybeInputZp, *maybeWeightZp);
 }
+
+// Helper function to extract the NaN propagation mode from an operation.
+// Note that for operations which support NaN mode propagation the attribute
+// is optional and its default value is "PROPAGATE".
+//
+// If the function is called with an operator that doesn't support the NaN
+// mode attribute it will return std::nullopt.
+inline std::optional<std::string> getNanMode(Operation *op,
+                                             PatternRewriter &rewriter) {
+  if (isa<tosa::ClampOp>(op) || isa<tosa::MaxPool2dOp>(op) ||
+      isa<tosa::ReduceMinOp>(op) || isa<tosa::ReduceMaxOp>(op) ||
+      isa<tosa::MaximumOp>(op) || isa<tosa::MinimumOp>(op) ||
+      isa<tosa::ArgMaxOp>(op))
+    return op->hasAttr("nan_mode") ? op->getAttrOfType<StringAttr>(
+                                           rewriter.getStringAttr("nan_mode"))
+                                         .str()
+                                   : "PROPAGATE";
+  return std::nullopt;
+}
+
 } // namespace tosa
 } // namespace mlir
 
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index d849c782bf08b..be2babaf64179 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -36,6 +36,47 @@
 using namespace mlir;
 using namespace mlir::tosa;
 
+// Helper function to materialize the semantically correct compare and select
+// operations for a binary operation with a specific NaN propagation mode.
+//
+// In the case of "PROPAGATE" semantics no compare and selection is required and
+// this function does nothing.
+//
+// In the case of "IGNORE" semantics this function materializes a comparison of
+// the current operands to the op which will return true for any NaN
+// argument and then selects between the non-NaN operation argument and the
+// calculated result based on whether the lhs or rhs is NaN or not. In pseudo
+// code:
+//
+// binary<op>(lhs, rhs):
+//   result = op(lhs, rhs)
+//   if lhs == NaN return rhs
+//   if rhs == NaN return lhs
+//   return result
+static Value materializeBinaryNanCheckIfRequired(Operation *op,
+                                                 PatternRewriter &rewriter,
+                                                 Value lhs, Value rhs,
+                                                 Value result) {
+  const auto nanMode = getNanMode(op, rewriter);
+  if (!nanMode)
+    return {};
+
+  if (*nanMode == "PROPAGATE")
+    return result;
+
+  assert(*nanMode == "IGNORE" && "Unhandled nan-propagation mode");
+
+  // Unordered comparison of NaN against itself will always return true.
+  Value lhsIsNaN = rewriter.create<arith::CmpFOp>(
+      op->getLoc(), arith::CmpFPredicate::UNO, lhs, lhs);
+  Value rhsIsNaN = rewriter.create<arith::CmpFOp>(
+      op->getLoc(), arith::CmpFPredicate::UNO, rhs, rhs);
+  Value rhsOrResult =
+      rewriter.create<arith::SelectOp>(op->getLoc(), lhsIsNaN, rhs, result);
+  return rewriter.create<arith::SelectOp>(op->getLoc(), rhsIsNaN, lhs,
+                                          rhsOrResult);
+}
+
 template <typename T>
 static arith::ConstantOp
 createConstFromIntAttribute(Operation *op, const std::string &attrName,
@@ -367,7 +408,9 @@ static Value createLinalgBodyCalculationForElementwiseOp(
 
   // tosa::MaximumOp
   if (isa<tosa::MaximumOp>(op) && isa<FloatType>(elementTy)) {
-    return rewriter.create<arith::MaximumFOp>(loc, args[0], args[1]);
+    auto max = rewriter.create<arith::MaximumFOp>(loc, args[0], args[1]);
+    return materializeBinaryNanCheckIfRequired(op, rewriter, args[0], args[1],
+                                               max);
   }
 
   if (isa<tosa::MaximumOp>(op) && elementTy.isSignlessInteger()) {
@@ -376,7 +419,9 @@ static Value createLinalgBodyCalculationForElementwiseOp(
 
   // tosa::MinimumOp
   if (isa<tosa::MinimumOp>(op) && isa<FloatType>(elementTy)) {
-    return rewriter.create<arith::MinimumFOp>(loc, args[0], args[1]);
+    auto min = rewriter.create<arith::MinimumFOp>(loc, args[0], args[1]);
+    return materializeBinaryNanCheckIfRequired(op, rewriter, args[0], args[1],
+                                               min);
   }
 
   if (isa<tosa::MinimumOp>(op) && elementTy.isSignlessInteger()) {
@@ -404,7 +449,34 @@ static Value createLinalgBodyCalculationForElementwiseOp(
         loc, elementTy, rewriter.getFloatAttr(elementTy, minApf));
     auto max = rewriter.create<arith::ConstantOp>(
         loc, elementTy, rewriter.getFloatAttr(elementTy, maxApf));
-    return clampFloatHelper(loc, args[0], min, max, rewriter);
+    auto result = clampFloatHelper(loc, args[0], min, max, rewriter);
+
+    const auto nanMode = getNanMode(op, rewriter);
+    if (!nanMode)
+      return {};
+
+    // In the case of "PROPAGATE" semantics no compare and selection is
+    // required.
+    if (*nanMode == "PROPAGATE")
+      return result;
+
+    // In the case of "IGNORE" semantics materialize a comparison
+    // of the current operand to the reduction which will return true for a NaN
+    // argument and then selects between the initial reduction value and the
+    // calculated result based on whether the argument is NaN or not. In pseudo
+    // code:
+    //
+    // reduce<op>(x, init):
+    //   result = op(init, x)
+    //   return init if x == NaN else result
+    assert(*nanMode == "IGNORE" && "Unhandled nan-propagation mode");
+
+    // Unordered comparison of NaN against itself will always return true.
+    Value isNaN = rewriter.create<arith::CmpFOp>(
+        op->getLoc(), arith::CmpFPredicate::UNO, args[0], args[0]);
+    // TOSA specifies that in "ignore" NaN mode the result is "min" if the input
+    // is NaN.
+    return rewriter.create<arith::SelectOp>(op->getLoc(), isNaN, min, result);
   }
 
   if (isa<tosa::ClampOp>(op) && isa<IntegerType>(elementTy)) {
@@ -1096,6 +1168,9 @@ static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
     }
   }
 
+  SmallVector<Value> inputs, outputs;
+  inputs.push_back(input);
+
   // First fill the output buffer with the init value.
   auto emptyTensor =
       rewriter
@@ -1113,26 +1188,124 @@ static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
                           .create<linalg::FillOp>(loc, ValueRange{fillValue},
                                                   ValueRange{emptyTensor})
                           .result();
+  outputs.push_back(filledTensor);
+
+  const auto nanMode = getNanMode(op, rewriter);
+  const bool isNanIgnoreMode = (nanMode && *nanMode == "IGNORE");
+  if (isNanIgnoreMode) {
+    // Because the TOSA spec requires the result to be NaN iff all elements
+    // in the reduction are NaN, we can't simply perform a compare and select.
+    // Instead we have to keep track of whether we've seen any non-NaN values
+    // and then do a final select based on that predicate.
+    auto trueAttr = rewriter.getBoolAttr(true);
+    auto trueValue = rewriter.create<arith::ConstantOp>(loc, trueAttr);
+    auto emptyBoolTensor =
+        rewriter
+            .create<tensor::EmptyOp>(loc, reduceShape, trueValue.getType(),
+                                     dynDims)
+            .getResult();
+    auto allResultsNaNTensor =
+        rewriter
+            .create<linalg::FillOp>(loc, ValueRange{trueValue},
+                                    ValueRange{emptyBoolTensor})
+            .result();
+    // Note that because the linalg::ReduceOp has two variadic arguments (inputs
+    // and outputs) and it has the SameVariadicOperandSize trait we need to have
+    // the same number of inputs and outputs.
+    //
+    // The second input isn't actually used anywhere since the value used to
+    // update the NaN flag is calculated inside the body of the reduction and
+    // then used to update an out value.
+    // In order to satisfy type constraints we just pass another copy of the
+    // input here.
+    inputs.push_back(input);
+    outputs.push_back(allResultsNaNTensor);
+  }
 
   bool didEncounterError = false;
-  auto linalgOp = rewriter.create<linalg::ReduceOp>(
-      loc, input, filledTensor, axis,
+  linalg::LinalgOp linalgOp = rewriter.create<linalg::ReduceOp>(
+      loc, inputs, outputs, axis,
       [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) {
+        std::array<Value, 2> binaryArgs{
+            blockArgs[0], isNanIgnoreMode ? blockArgs[2] : blockArgs[1]};
         auto result = createLinalgBodyCalculationForReduceOp(
-            op, blockArgs, elementTy, rewriter);
+            op, binaryArgs, elementTy, rewriter);
         if (result)
           didEncounterError = true;
 
-        nestedBuilder.create<linalg::YieldOp>(loc, result);
+        SmallVector<Value> resultsToYield;
+        if (isNanIgnoreMode) {
+          auto inputValue = blockArgs[0];
+          auto initialValue = blockArgs[2];
+          auto oldAllResultsNanFlagValue = blockArgs[3];
+
+          // Unordered comparison of NaN against itself will always return true.
+          Value isNaN = nestedBuilder.create<arith::CmpFOp>(
+              op->getLoc(), arith::CmpFPredicate::UNO, inputValue, inputValue);
+          // If the input is NaN, keep the old (non-NaN) accumulator value.
+          auto selectOp = nestedBuilder.create<arith::SelectOp>(
+              op->getLoc(), isNaN, initialValue, result);
+          // Update the flag which keeps track of whether we have seen a non-NaN
+          // value.
+          auto newAllResultsNanFlagValue = nestedBuilder.create<arith::AndIOp>(
+              op->getLoc(), oldAllResultsNanFlagValue, isNaN);
+          resultsToYield.push_back(selectOp);
+          resultsToYield.push_back(newAllResultsNanFlagValue);
+        } else {
+          resultsToYield.push_back(result);
+        }
+        nestedBuilder.create<linalg::YieldOp>(loc, resultsToYield);
       });
 
   if (!didEncounterError)
     return rewriter.notifyMatchFailure(
         op, "unable to create linalg.generic body for reduce op");
 
+  if (isNanIgnoreMode) {
+    // Materialize a check to see whether we encountered any non-NaN values.
+    // If we didn't, we need to select a tensor of NaNs since the result will
+    // just be the initial identity value propagated through all the compares
+    // and selects inside the reduction.
+
+    // Create a tensor full of NaNs.
+    auto nanValueAttr = rewriter.getFloatAttr(
+        elementTy,
+        APFloat::getNaN(cast<FloatType>(elementTy).getFloatSemantics(), false));
+    auto nanValue = rewriter.create<arith::ConstantOp>(loc, nanValueAttr);
+    auto emptyNanTensor =
+        rewriter
+            .create<tensor::EmptyOp>(loc, reduceShape,
+                                     resultTy.getElementType(), dynDims)
+            .getResult();
+    auto nanFilledTensor =
+        rewriter
+            .create<linalg::FillOp>(loc, ValueRange{nanValue},
+                                    ValueRange{emptyNanTensor})
+            .result();
+
+    // Create an empty tensor; no need to fill it since it will be
+    // overwritten by the select.
+    auto finalEmptyTensor =
+        rewriter
+            .create<tensor::EmptyOp>(loc, reduceShape,
+                                     resultTy.getElementType(), dynDims)
+            .getResult();
+
+    // Do a selection between the tensors akin to:
+    // result = NaN if "all results NaN" else result.
+    SmallVector<Value> ins, outs;
+    ins.push_back(linalgOp->getOpResult(1));
+    ins.push_back(nanFilledTensor);
+    ins.push_back(linalgOp->getResult(0));
+    outs.push_back(finalEmptyTensor);
+    auto linalgSelect =
+        rewriter.create<linalg::SelectOp>(op->getLoc(), ins, outs);
+    linalgOp = linalgSelect;
+  }
+
   SmallVector<ReassociationExprs, 4> reassociationMap;
   uint64_t expandInputRank =
-      cast<ShapedType>(linalgOp.getResults()[0].getType()).getRank();
+      cast<ShapedType>(linalgOp->getResults()[0].getType()).getRank();
   reassociationMap.resize(expandInputRank);
 
   for (uint64_t i = 0; i < expandInputRank; i++) {
@@ -1151,7 +1324,7 @@ static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
   // not have access to such information. This matters when handling dynamically
   // sized tensors.
   rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
-      op, resultTy, linalgOp.getResults()[0], reassociationMap);
+      op, resultTy, linalgOp->getResults()[0], reassociationMap);
   return success();
 }
 
@@ -2088,6 +2261,32 @@ class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
               nestedLoc, predicate, newValue, oldValue);
           auto resultIndex = rewriter.create<arith::SelectOp>(
               nestedLoc, predicate, newIndex, oldIndex);
+
+          // Check if we need to materialize compare and select for the given
+          // NaN propagation mode.
+          const auto nanMode = getNanMode(argmaxOp, rewriter);
+          if (!nanMode) {
+            didEncounterError = true;
+            return;
+          }
+
+          // "PROPAGATE" matches the default NaN propagation mode of the arith
+          // dialect so no compare and select is required.
+          //
+          // In the case "IGNORE" we check if the current argument is NaN and
+          // select the old index and value otherwise take the updated index and
+          // value.
+          if (*nanMode == "IGNORE") {
+            // Unordered comparison of NaN against itself will always return
+            // true.
+            Value isNaN = rewriter.create<arith::CmpFOp>(
+                argmaxOp.getLoc(), arith::CmpFPredicate::UNO, newValue,
+                newValue);
+            resultMax = rewriter.create<arith::SelectOp>(nestedLoc, isNaN,
+                                                         oldValue, resultMax);
+            resultIndex = rewriter.create<arith::SelectOp>(
+                nestedLoc, isNaN, oldIndex, resultIndex);
+          }
           nestedBuilder.create<linalg::YieldOp>(
               nestedLoc, ValueRange({resultIndex, resultMax}));
         });
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index 6321cb6087394..441ae0f456ca5 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -802,11 +802,47 @@ class MaxPool2dConverter : public OpConversionPattern<tosa::MaxPool2dOp> {
       rewriter.replaceOpWithNewOp<linalg::PoolingNhwcMaxUnsignedOp>(
           op, ArrayRef<Type>{resultTy}, ValueRange{paddedInput, fakeWindowDims},
           filledEmptyTensor, strideAttr, dilationAttr);
-    } else {
-      rewriter.replaceOpWithNewOp<linalg::PoolingNhwcMaxOp>(
-          op, ArrayRef<Type>{resultTy}, ValueRange{paddedInput, fakeWindowDims},
-          filledEmptyTensor, strideAttr, dilationAttr);
+      return llvm::success();
     }
+
+    auto resultOp = rewriter.create<linalg::PoolingNhwcMaxOp>(
+        op->getLoc(), ArrayRef<Type>{resultTy},
+        ValueRange{paddedInput, fakeWindowDims}, filledEmptyTensor, strideAttr,
+        dilationAttr);
+
+    // Check that the NaN propagation mode is present.
+    const auto nanMode = getNanMode(op, rewriter);
+    if (!nanMode)
+      return failure();
+
+    // "PROPAGATE" mode matches the behaviour of the LinAlg named op, so no
+    // compare and select materialization is required.
+    //
+    // In the case of "IGNORE" we need to insert a compare and select. Since
+    // we've already produced a named op we will just take its body and modify
+    // it to include the appropriate checks. If the current value is NaN the
+    // old value of pool will be taken otherwise we use the result.
+    if (nanMode == "IGNORE") {
+      auto *block = resultOp.getBlock();
+      rewriter.setInsertionPointToEnd(block);
+
+      auto in = block->getArgument(0);
+      auto out = block->getArgument(2);
+
+      auto *oldYieldOp = &*block->rbegin();
+      auto result = oldYieldOp->getOperand(0);
+
+      Value isNaN = rewriter.create<arith::CmpFOp>(
+          op->getLoc(), arith::CmpFPredicate::UNO, in, in);
+
+      auto selectOp =
+          rewriter.create<arith::SelectOp>(op->getLoc(), isNaN, out, result);
+      auto newYieldOp = rewriter.create<linalg::YieldOp>(oldYieldOp->getLoc(),
+                                                         selectOp.getResult());
+      rewriter.replaceOp(oldYieldOp, newYieldOp);
+    }
+
+    rewriter.replaceOp(op, resultOp);
     return success();
   }
 };
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index 87c388b6f5ee3..a70f07f0aa358 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -1,6 +1,7 @@
 // RUN: mlir-opt --verify-each --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named))" %s -verify-diagnostics -o -| FileCheck %s
 // RUN: mlir-opt --verify-each --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named{prefer-conv2d-kernel-layout-hwcf=true}))" %s -verify-diagnostics -o -| FileCheck --check-prefix="HWCF" %s
 // RUN: mlir-opt --verify-each --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named,cse))" %s -verify-diagnostics -o -| FileCheck --check-prefix="CHECK-CSE" %s
+// RUN: mlir-opt --verify-each --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named,linalg-generalize-named-ops))" %s -verify-diagnostics -o -| FileCheck %s --check-prefix="CHECK-NAN"
 
 // CHECK-LABEL: @matmul
 func.func @matmul(%arg0: tensor<1x5x3xf32>, %arg1: tensor<1x3x6xf32>) -> (tensor<1x5x6xf32>) {
@@ -977,3 +978,24 @@ func.func @test_transpose_dyn_multiple_3d(%arg0: tensor<?x?x?xf32>) {
   %1 = "tosa.transpose"(%arg0, %0) : (tensor<?x?x?xf32>, tensor<3xi32>) -> tensor<?x?x?xf32>
   return
 }
+
+// -----
+
+// CHECK-NAN-LABEL: @nan_propagation_modes
+func.func @nan_propagation_modes(%arg0: tensor<1x6x34x62xf32>) -> (tensor<1x4x32x62xf32>, tensor<1x4x32x62xf32>) {
+  // CHECK-NAN: linalg.generic
+  // CHECK-NAN-NOT: arith.maximumf
+  // CHECK-NAN-NOT: arith.cmpf uno
+  // CHECK-NAN-NOT: arith.select
+  // CHECK-NAN: linalg.yield
+  %0 = tosa.max_pool2d %arg0 {pad = array<i64: 0, 0, 0, 0>, kernel = array<i64: 3, 3>, stride = array<i64: 1, 1>, nan_mode = "PROPAGATE"} : (tensor<1x6x34x62xf32>) -> tensor<1x4x32x62xf32>
+
+  // CHECK-NAN: linalg.generic
+  // CHECK-NAN: arith.maximumf
+  // CHECK-NAN: arith.cmpf uno
+  // CHECK-NAN: arith.select
+  // CHECK-NAN: linalg.yield
+  %1 = tosa.max_pool2d %arg0 {pad = array<i64: 0, 0, 0, 0>, kernel = array<i64: 3, 3>, stride = array<i64: 1, 1>, nan_mode = "IGNORE"} : (tensor<1x6x34x62xf32>) -> tensor<1x4x32x62xf32>
+
+  return %0, %1 : tensor<1x4x32x62xf32>, tensor<1x4x32x62xf32>
+}
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 17add2d41afe7..51160fee2e7e1 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1992,3 +1992,127 @@ func.func @test_cast_fp32_i64(%arg0: tensor<1xf32>) -> (tensor<1xi64>) {
   %0 = tosa.cast %arg0 : (tensor<1xf32>) -> tensor<1xi64>
   return %0: tensor<1xi64>
 }
+
+// -----
+
+// CHECK-LABEL: @nan_propagation_modes
+func.func @nan_propagation_modes(%arg0: tensor<5x4xf32>, %arg1: tensor<5x4xf32>) -> () {
+  // CHECK: linalg.reduce
+  // CHECK: arith.minimumf
+  // CHECK-NOT: arith.cmpf uno
+  // CHECK-NOT: arith.select
+  // CHECK: linalg.yield
+  // CHECK-NOT: arith.constant 0x7FC00000
+  // CHECK-NOT: tensor.empty()
+  // CHECK-NOT: linalg.fill
+  // CHECK-NOT: tensor.empty()
+  // CHECK-NOT: select
+  %3 = tosa.reduce_min %arg0 {axis = 0 : i32, nan_mode = "PROPAGATE"} : (tensor<5x4xf32>) -> tensor<1x4xf32>
+
+  // CHECK: linalg.reduce
+  // CHECK: arith.maximumf
+  // CHECK-NOT: arith.cmpf uno
+  // CHECK-NOT: arith.select
+  // CHECK: linalg.yield
+  // CHECK-NOT: arith.constant 0x7FC00000
+  // CHECK-NOT: tensor.empty()
+  // CHECK-NOT: linalg.fill
+  // CHECK-NOT: tensor.empty()
+  // CHECK-NOT: select
+  %4 = tosa.reduce_max %arg0 {axis = 0 : i32, nan_mode = "PROPAGATE"} : (tensor<5x4xf32>) -> tensor<1x4xf32>
+
+  // CHECK: linalg.reduce
+  // CHECK: arith.minimumf
+  // CHECK: arith.cmpf uno
+  // CHECK: arith.select
+  // CHECK: linalg.yield
+  // CHECK: arith.constant 0x7FC00000
+  // CHECK: tensor.empty()
+  // CHECK: linalg.fill
+  // CHECK: tensor.empty()
+  // CHECK: select
+  %5 = tosa.reduce_min %arg0 {axis = 0 : i32, nan_mode = "IGNORE"} : (tensor<5x4xf32>) -> tensor<1x4xf32>
+
+  // CHECK: linalg.reduce
+  // CHECK: arith.maximumf
+  // CHECK: arith.cmpf uno
+  // CHECK: arith.select
+  // CHECK: linalg.yield
+  // CHECK: arith.constant 0x7FC00000
+  // CHECK: tensor.empty()
+  // CHECK: linalg.fill
+  // CHECK: tensor.empty()
+  // CHECK: select
+  %6 = tosa.reduce_max %arg0 {axis = 0 : i32, nan_mode = "IGNORE"} : (tensor<5x4xf32>) -> tensor<1x4xf32>
+
+  // CHECK: linalg.generic
+  // CHECK: arith.minimumf
+  // CHECK-NOT: arith.cmpf uno
+  // CHECK-NOT: arith.select
+  // CHECK: linalg.yield
+  %7 = tosa.minimum %arg0, %arg1 {nan_mode = "PROPAGATE"} : (tensor<5x4xf32>, tensor<5x4xf32>) -> tensor<5x4xf32>
+
+  // CHECK: linalg.generic
+  // CHECK: arith.maximumf
+  // CHECK-NOT: arith.cmpf uno
+  // CHECK-NOT: arith.select
+  // CHECK: linalg.yield
+  %8 = tosa.maximum %arg0, %arg1 {nan_mode = "PROPAGATE"} : (tensor<5x4xf32>, tensor<5x4xf32>) -> tensor<5x4xf32>
+
+  // CHECK: linalg.generic
+  // CHECK: arith.minimumf
+  // CHECK: arith.cmpf uno
+  // CHECK: arith.cmpf uno
+  // CHECK: arith.select
+  // CHECK: arith.select
+  // CHECK: linalg.yield
+  %9 = tosa.minimum %arg0, %arg1 {nan_mode = "IGNORE"} : (tensor<5x4xf32>, tensor<5x4xf32>) -> tensor<5x4xf32>
+
+  // CHECK: linalg.generic
+  // CHECK: arith.maximumf
+  // CHECK: arith.cmpf uno
+  // CHECK: arith.cmpf uno
+  // CHECK: arith.select
+  // CHECK: arith.select
+  // CHECK: linalg.yield
+  %10 = tosa.maximum %arg0, %arg1 {nan_mode = "IGNORE"} : (tensor<5x4xf32>, tensor<5x4xf32>) -> tensor<5x4xf32>
+
+  // CHECK: linalg.generic
+  // CHECK: arith.cmpf ogt
+  // CHECK: arith.select
+  // CHECK: arith.select
+  // CHECK-NOT: arith.cmpf uno
+  // CHECK-NOT: arith.cmpf uno
+  // CHECK-NOT: arith.select
+  // CHECK-NOT: arith.select
+  // CHECK: linalg.yield
+  %11 = tosa.argmax %arg0 {axis = 0 : i32, nan_mode = "PROPAGATE"} : (tensor<5x4xf32>)  -> tensor<4xi32>
+
+  // CHECK: linalg.generic
+  // CHECK: arith.cmpf ogt
+  // CHECK: arith.select
+  // CHECK: arith.select
+  // CHECK: arith.cmpf uno
+  // CHECK: arith.select
+  // CHECK: arith.select
+  // CHECK: linalg.yield
+  %12 = tosa.argmax %arg0 {axis = 0 : i32, nan_mode = "IGNORE"} : (tensor<5x4xf32>)  -> tensor<4xi32>
+
+  // CHECK: linalg.generic
+  // CHECK: arith.minimumf
+  // CHECK: arith.maximumf
+  // CHECK-NOT: arith.cmpf uno
+  // CHECK-NOT: arith.select
+  // CHECK: linalg.yield
+  %13 = tosa.clamp %arg0 {min_int = 1 : i64, max_int = 5 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32, nan_mode = "PROPAGATE"} : (tensor<5x4xf32>) -> tensor<5x4xf32>
+
+  // CHECK: linalg.generic
+  // CHECK: arith.minimumf
+  // CHECK: arith.maximumf
+  // CHECK: arith.cmpf uno
+  // CHECK: arith.select
+  // CHECK: linalg.yield
+  %14 = tosa.clamp %arg0 {min_int = 1 : i64, max_int = 5 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32, nan_mode = "IGNORE"} : (tensor<5x4xf32>) -> tensor<5x4xf32>
+
+  return
+}


