[Mlir-commits] [mlir] [mlir][linalg] Fix neutral elt for softmax (PR #118952)

Clément Fournier llvmlistbot at llvm.org
Fri Dec 6 02:32:58 PST 2024


https://github.com/oowekyala created https://github.com/llvm/llvm-project/pull/118952

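The decomposition of `linalg.softmax` uses `maxnumf`, but the identity element used in the generated code is the one for `maximumf`. The two differ: the identity for `maxnumf` is `NaN`, while the identity for `maximumf` is `-Inf` (or, with `useOnlyFiniteValue` set as it is here, the lowest finite value, which is the `-3.40282347E+38` that the test previously expected). Using the `maximumf` identity with `maxnumf` is wrong and prevents the `maxnumf` from being folded.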

Related to #114595, which fixed the folder for maxnumf.
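
To illustrate why the two identities differ, here is a minimal standalone C++ sketch. It does not use MLIR. The helpers `maxnum_like`, `maximum_like`, `reduce_max` and `from_bits` are hypothetical names introduced only for this illustration. `std::fmax` stands in for `maxnumf` because it returns the non-NaN operand when exactly one operand is NaN, and the hand-written `maximum_like` stands in for `maximumf` by propagating NaN. The sketch assumes IEEE 754 binary32 `float`.

```cpp
#include <cmath>
#include <cstdint>
#include <cstring>
#include <limits>

// Hypothetical stand-in for arith.maxnumf: std::fmax returns the
// non-NaN operand when exactly one of the operands is NaN.
inline float maxnum_like(float a, float b) { return std::fmax(a, b); }

// Hypothetical stand-in for arith.maximumf: NaN in, NaN out.
inline float maximum_like(float a, float b) {
  if (std::isnan(a) || std::isnan(b))
    return std::numeric_limits<float>::quiet_NaN();
  return a > b ? a : b;
}

// Left fold of a max-like operation over n values, starting from init.
// This mirrors the shape of the max reduction in the softmax decomposition.
template <typename Op>
float reduce_max(const float *xs, int n, float init, Op op) {
  float acc = init;
  for (int i = 0; i < n; ++i)
    acc = op(acc, xs[i]);
  return acc;
}

// Reinterpret a 32-bit pattern as a float (assumes IEEE 754 binary32),
// e.g. to inspect the 0xFFC00000 constant from the updated test.
inline float from_bits(std::uint32_t bits) {
  float f;
  std::memcpy(&f, &bits, sizeof f);
  return f;
}
```

With these helpers, `maxnum_like(NaN, x)` gives back `x` for every non-NaN `x`, including `-Inf`, so `NaN` behaves as a neutral element. The lowest finite value does not: `maxnum_like(lowest, -Inf)` yields `lowest` rather than `-Inf`. Conversely, `maximum_like(NaN, x)` is always NaN, so NaN cannot serve as the identity for `maximumf`. `from_bits(0xFFC00000)` is a quiet NaN with the sign bit set, which matches the constant in the updated CHECK line of the patch.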

From 98506b2aa61cf607e390e87777af33ab47c260c9 Mon Sep 17 00:00:00 2001
From: Clément Fournier <clement.fournier at tu-dresden.de>
Date: Fri, 1 Nov 2024 20:13:38 +0100
Subject: [PATCH] Fix neutral elt for softmax

---
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp             | 2 +-
 mlir/test/Dialect/Linalg/transform-op-decompose.mlir | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index d9840e3923c4f7..133855cc389338 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2892,7 +2892,7 @@ FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
   dims.erase(dims.begin() + reductionDim);
   // Step 1: Compute max along dim.
   Value outputReduce = b.create<tensor::EmptyOp>(loc, dims, elementType);
-  Value neutralForMaxF = arith::getIdentityValue(arith::AtomicRMWKind::maximumf,
+  Value neutralForMaxF = arith::getIdentityValue(arith::AtomicRMWKind::maxnumf,
                                                  elementType, b, loc,
                                                  /*useOnlyFiniteValue=*/true);
   Value neutralForMaxFInit =
diff --git a/mlir/test/Dialect/Linalg/transform-op-decompose.mlir b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
index 2e211d2fa7dbe9..72acf43361f501 100644
--- a/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
@@ -210,7 +210,7 @@ func.func @softmax(%arg0: tensor<2x16x32xf32>, %dst: tensor<2x16x32xf32>) -> ten
 // CHECK-LABEL:      func.func @softmax(
 // CHECK-SAME:           %[[ARG0:[a-zA-Z0-9_]+]]: tensor<2x16x32xf32>, %[[DST:[a-zA-Z0-9_]+]]: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
 // CHECK-DAG:        %[[D1:.+]] = tensor.empty() : tensor<2x16xf32>
-// CHECK-DAG:        %[[CST:.+]] = arith.constant -3.40282347E+38 : f32
+// CHECK-DAG:        %[[CST:.+]] = arith.constant 0xFFC00000 : f32
 // CHECK:        %[[D2:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D1]] : tensor<2x16xf32>) -> tensor<2x16xf32>
 // CHECK:        %[[D3:.+]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]]], iterator_types = ["parallel",
 // CHECK-SAME:     "parallel", "reduction"]} ins(%[[ARG0]] : tensor<2x16x32xf32>) outs(%[[D2]] : tensor<2x16xf32>) {


