[Mlir-commits] [mlir] 36c3466 - [mlir][linalg] Fix neutral elt for softmax (#118952)
llvmlistbot at llvm.org
Sun Jan 12 23:21:11 PST 2025
Author: Clément Fournier
Date: 2025-01-13T15:21:07+08:00
New Revision: 36c3466aef6c8bfde0ddc736b8403e2c45f5e1c6
URL: https://github.com/llvm/llvm-project/commit/36c3466aef6c8bfde0ddc736b8403e2c45f5e1c6
DIFF: https://github.com/llvm/llvm-project/commit/36c3466aef6c8bfde0ddc736b8403e2c45f5e1c6.diff
LOG: [mlir][linalg] Fix neutral elt for softmax (#118952)
The decomposition of `linalg.softmax` uses `maxnumf`, but the identity
element used in the generated code is the one for `maximumf`.
They are not the same: the identity for `maxnumf` is `NaN`, while the
one for `maximumf` is `-inf` (`-FLT_MAX` when only finite values are
allowed, as here). This is wrong, and it prevents the `maxnumf` from
being folded.
Related to #114595, which fixed the folder for `maxnumf`.
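The identity claim can be checked on plain floats. The C++ sketch below assumes `std::fmax` as a stand-in for `arith.maxnumf` (it follows the same NaN rule: a NaN operand yields the other operand) and uses a hypothetical `maximumLike` helper for the NaN-propagating rule of `arith.maximumf`. `reduceRow` is also hypothetical; it only imitates the max reduction that the decomposed softmax seeds with the neutral value before reducing.

```cpp
#include <cassert>
#include <cmath>
#include <limits>
#include <vector>

// Stand-in for arith.maxnumf: if one operand is NaN, the other operand is
// returned. std::fmax follows this IEEE 754 maxNum rule.
float maxnumLike(float a, float b) { return std::fmax(a, b); }

// Hypothetical stand-in for arith.maximumf: any NaN operand makes the result
// NaN. Signed-zero ordering (-0.0 < +0.0) is ignored in this sketch.
float maximumLike(float a, float b) {
  if (std::isnan(a) || std::isnan(b))
    return std::numeric_limits<float>::quiet_NaN();
  return a > b ? a : b;
}

// Hypothetical helper: folds a row into an accumulator that starts at `init`,
// like the max reduction whose output is first filled with the neutral value.
template <typename Combine>
float reduceRow(const std::vector<float> &row, float init, Combine combine) {
  float acc = init;
  for (float x : row)
    acc = combine(acc, x);
  return acc;
}
```

With a `NaN` seed, `reduceRow(row, NaN, maxnumLike)` returns exactly the row maximum, even for a row containing only `-inf`, whereas seeding it with the finite `maximumf` identity (`-FLT_MAX`) leaks that seed into the result for such a row. Conversely, a `NaN` seed turns every `maximumLike` reduction into `NaN`, which is why each operation needs its own identity.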
Added:
Modified:
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/test/Dialect/Linalg/transform-op-decompose.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 8973e87c063b33..c13b663dbf05b1 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2890,7 +2890,7 @@ FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
dims.erase(dims.begin() + reductionDim);
// Step 1: Compute max along dim.
Value outputReduce = b.create<tensor::EmptyOp>(loc, dims, elementType);
- Value neutralForMaxF = arith::getIdentityValue(arith::AtomicRMWKind::maximumf,
+ Value neutralForMaxF = arith::getIdentityValue(arith::AtomicRMWKind::maxnumf,
elementType, b, loc,
/*useOnlyFiniteValue=*/true);
Value neutralForMaxFInit =
diff --git a/mlir/test/Dialect/Linalg/transform-op-decompose.mlir b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
index 2e211d2fa7dbe9..72acf43361f501 100644
--- a/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
@@ -210,7 +210,7 @@ func.func @softmax(%arg0: tensor<2x16x32xf32>, %dst: tensor<2x16x32xf32>) -> ten
// CHECK-LABEL: func.func @softmax(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<2x16x32xf32>, %[[DST:[a-zA-Z0-9_]+]]: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
// CHECK-DAG: %[[D1:.+]] = tensor.empty() : tensor<2x16xf32>
-// CHECK-DAG: %[[CST:.+]] = arith.constant -3.40282347E+38 : f32
+// CHECK-DAG: %[[CST:.+]] = arith.constant 0xFFC00000 : f32
// CHECK: %[[D2:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D1]] : tensor<2x16xf32>) -> tensor<2x16xf32>
// CHECK: %[[D3:.+]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]]], iterator_types = ["parallel",
// CHECK-SAME: "parallel", "reduction"]} ins(%[[ARG0]] : tensor<2x16x32xf32>) outs(%[[D2]] : tensor<2x16xf32>) {
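In the updated test, the fill constant is expected as the bit pattern `0xFFC00000` instead of `-3.40282347E+38`. Read as an IEEE 754 binary32 value, `0xFFC00000` has the sign bit set, an all-ones exponent, and the top mantissa bit set, i.e. a negative quiet NaN; the old constant is `-FLT_MAX`, bit pattern `0xFF7FFFFF`. A minimal C++ sketch to decode both (the helper names `bitsToFloat` and `floatToBits` are hypothetical):

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstring>
#include <limits>

// Reinterprets a 32-bit pattern as an IEEE 754 binary32 (f32) value.
float bitsToFloat(std::uint32_t bits) {
  static_assert(sizeof(float) == sizeof(std::uint32_t), "f32 must be 32 bits");
  float f;
  std::memcpy(&f, &bits, sizeof f);
  return f;
}

// Returns the raw 32-bit pattern of an f32 value.
std::uint32_t floatToBits(float f) {
  std::uint32_t bits;
  std::memcpy(&bits, &f, sizeof bits);
  return bits;
}
```

For example, `bitsToFloat(0xFFC00000u)` satisfies both `std::isnan` and `std::signbit`, and `floatToBits(-std::numeric_limits<float>::max())` yields `0xFF7FFFFFu`.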