[Mlir-commits] [mlir] [mlir][tosa][tosa-to-linalg] Add NaN Mode Lowering (PR #125668)

Jack Frankland llvmlistbot at llvm.org
Tue Feb 25 03:00:11 PST 2025


================
@@ -724,11 +724,44 @@ class MaxPool2dConverter : public OpConversionPattern<tosa::MaxPool2dOp> {
       rewriter.replaceOpWithNewOp<linalg::PoolingNhwcMaxUnsignedOp>(
           op, ArrayRef<Type>{resultTy}, ValueRange{paddedInput, fakeWindowDims},
           filledEmptyTensor, strideAttr, dilationAttr);
-    } else {
-      rewriter.replaceOpWithNewOp<linalg::PoolingNhwcMaxOp>(
-          op, ArrayRef<Type>{resultTy}, ValueRange{paddedInput, fakeWindowDims},
-          filledEmptyTensor, strideAttr, dilationAttr);
+      return llvm::success();
     }
+
+    auto resultOp = rewriter.create<linalg::PoolingNhwcMaxOp>(
+        op->getLoc(), ArrayRef<Type>{resultTy},
+        ValueRange{paddedInput, fakeWindowDims}, filledEmptyTensor, strideAttr,
+        dilationAttr);
+
+    rewriter.replaceOp(op, resultOp);
+    // "PROPAGATE" mode matches the behaviour of the LinAlg named op, so no
+    // compare and select materialization is required.
+    //
+    // In the case of "IGNORE" we need to insert a compare and select. Since
+    // we've already produced a named op we will just take its body and modify
+    // it to include the appropriate checks. If the current value is NaN the
+    // old value of pool will be taken otherwise we use the result.
+    if (const auto nanMode = op.getNanMode(); nanMode == "IGNORE") {
+      auto genericOp = rewriter.create<linalg::GenericOp>(
+          op->getLoc(), resultOp.getType(0), resultOp.getInputs(),
+          resultOp.getOutputs(), resultOp.getIndexingMapsArray(),
+          resultOp.getIteratorTypesArray(),
+          [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
+            IRMapping map;
+            auto oldBlock = resultOp.getRegion().begin();
+            auto oldArgs = oldBlock->getArguments();
+            auto &oldMaxOp = *resultOp.getBlock()->begin();
+            map.map(oldArgs, blockArgs);
+            auto *newOp = opBuilder.clone(oldMaxOp, map);
+            Value isNaN = opBuilder.create<arith::CmpFOp>(
----------------
FranklandJack wrote:

This follows the same pattern as clamp: a single compare and a single select. The difference is that we only do this materialization in IGNORE mode, and in that case we have already had to build a generic. We could introduce a new function, called something like `buildUnaryNanIgnore`, and call it both inside `materializeBinaryNanCheckIfRequired` and here, but I think it pays to be explicit here.
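To make the difference between the two modes concrete, here is a minimal scalar C++ model of reducing one pooling window. It is not MLIR code. `NanMode`, `maximumF`, and `poolWindow` are hypothetical names chosen for illustration. `maximumF` stands in for the NaN-propagating maximum in the named op's body, and it does not model signed-zero ordering. The initial accumulator is assumed to be the lowest finite float, standing in for the output tensor's fill value. In IGNORE mode the sketch applies the compare and select described in the diff: if the incoming value is NaN, it keeps the old accumulator; otherwise it takes the cloned max result.

```cpp
#include <cmath>
#include <limits>
#include <vector>

// Hypothetical model of the two TOSA nan_mode settings discussed above.
enum class NanMode { Propagate, Ignore };

// NaN-propagating maximum: returns NaN if either operand is NaN.
// (Signed-zero ordering is not modelled in this sketch.)
float maximumF(float a, float b) {
  if (std::isnan(a) || std::isnan(b))
    return std::numeric_limits<float>::quiet_NaN();
  return a > b ? a : b;
}

// Reduces one pooling window. The accumulator starts at an assumed fill
// value (the lowest finite float).
float poolWindow(const std::vector<float> &window, NanMode mode) {
  float acc = std::numeric_limits<float>::lowest();
  for (float in : window) {
    // Equivalent of the cloned max op from the named op's body.
    float result = maximumF(in, acc);
    if (mode == NanMode::Ignore) {
      // Compare: the input is unordered with itself exactly when it is NaN.
      bool isNaN = std::isnan(in);
      // Select: keep the old accumulator for NaN inputs.
      acc = isNaN ? acc : result;
    } else {
      // PROPAGATE: use the max result directly, so one NaN poisons the window.
      acc = result;
    }
  }
  return acc;
}
```

For the window `{1, NaN, 3}`, PROPAGATE yields NaN while IGNORE yields 3. Note that in this sketch an all-NaN window under IGNORE leaves the accumulator at its initial value.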

https://github.com/llvm/llvm-project/pull/125668

