[Mlir-commits] [mlir] 0981dca - [mlir][arith] Add neutral element support to arith.maxnumf/arith.minnumf (#93278)
llvmlistbot at llvm.org
Wed May 29 07:20:53 PDT 2024
Author: donald chen
Date: 2024-05-29T10:20:49-04:00
New Revision: 0981dca7779d4acfcbb92fbb29a7a1033e283b88
URL: https://github.com/llvm/llvm-project/commit/0981dca7779d4acfcbb92fbb29a7a1033e283b88
DIFF: https://github.com/llvm/llvm-project/commit/0981dca7779d4acfcbb92fbb29a7a1033e283b88.diff
LOG: [mlir][arith] Add neutral element support to arith.maxnumf/arith.minnumf (#93278)
For maxnumf and minnumf, when one operand is NaN the result is the other
operand, so their neutral element is set to NaN.
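The property can be checked outside MLIR: IEEE-754 maxNum/minNum (and the C++
std::fmax/std::fmin functions, which follow the same rule) return the non-NaN
operand when exactly one operand is NaN. A minimal standalone C++ sketch (not
part of this commit) illustrating why NaN acts as the neutral element:

  #include <cassert>
  #include <cmath>
  #include <limits>

  int main() {
    // Quiet NaN plays the role of the neutral element used to fill the
    // init tensor of the split reduction.
    const float nan = std::numeric_limits<float>::quiet_NaN();

    // fmax/fmin return the non-NaN operand, so starting a max/min
    // reduction from NaN leaves every real operand unchanged.
    assert(std::fmax(nan, 3.0f) == 3.0f);
    assert(std::fmin(nan, 3.0f) == 3.0f);
    assert(std::fmax(-2.0f, nan) == -2.0f);
    assert(std::fmin(-2.0f, nan) == -2.0f);
    return 0;
  }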
Added:
Modified:
mlir/lib/Dialect/Arith/IR/ArithOps.cpp
mlir/test/Dialect/Linalg/transform-op-split-reduction.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index a0b50251c6b67..5797c5681a5fd 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -2467,6 +2467,12 @@ TypedAttr mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
: APFloat::getInf(semantic, /*Negative=*/true);
return builder.getFloatAttr(resultType, identity);
}
+ case AtomicRMWKind::maxnumf: {
+ const llvm::fltSemantics &semantic =
+ llvm::cast<FloatType>(resultType).getFloatSemantics();
+ APFloat identity = APFloat::getNaN(semantic, /*Negative=*/true);
+ return builder.getFloatAttr(resultType, identity);
+ }
case AtomicRMWKind::addf:
case AtomicRMWKind::addi:
case AtomicRMWKind::maxu:
@@ -2489,6 +2495,12 @@ TypedAttr mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
return builder.getFloatAttr(resultType, identity);
}
+ case AtomicRMWKind::minnumf: {
+ const llvm::fltSemantics &semantic =
+ llvm::cast<FloatType>(resultType).getFloatSemantics();
+ APFloat identity = APFloat::getNaN(semantic, /*Negative=*/false);
+ return builder.getFloatAttr(resultType, identity);
+ }
case AtomicRMWKind::mins:
return builder.getIntegerAttr(
resultType, APInt::getSignedMaxValue(
@@ -2518,6 +2530,8 @@ std::optional<TypedAttr> mlir::arith::getNeutralElement(Operation *op) {
.Case([](arith::MulFOp op) { return AtomicRMWKind::mulf; })
.Case([](arith::MaximumFOp op) { return AtomicRMWKind::maximumf; })
.Case([](arith::MinimumFOp op) { return AtomicRMWKind::minimumf; })
+ .Case([](arith::MaxNumFOp op) { return AtomicRMWKind::maxnumf; })
+ .Case([](arith::MinNumFOp op) { return AtomicRMWKind::minnumf; })
// Integer operations.
.Case([](arith::AddIOp op) { return AtomicRMWKind::addi; })
.Case([](arith::OrIOp op) { return AtomicRMWKind::ori; })
diff --git a/mlir/test/Dialect/Linalg/transform-op-split-reduction.mlir b/mlir/test/Dialect/Linalg/transform-op-split-reduction.mlir
index 31e9fd00cffa0..9849f36285b16 100644
--- a/mlir/test/Dialect/Linalg/transform-op-split-reduction.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-split-reduction.mlir
@@ -407,3 +407,95 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
+
+// -----
+// Checks we use nan as the neutral element for maxnumf op.
+func.func @generic_split_maxnumf(%in: tensor<32xf32>, %out: tensor<f32>) -> tensor<f32> {
+ %r = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
+ affine_map<(d0) -> ()>],
+ iterator_types = ["reduction"]}
+ ins(%in : tensor<32xf32>)
+ outs(%out : tensor<f32>) {
+ ^bb0(%arg1: f32, %arg2: f32):
+ %y = arith.maxnumf %arg1, %arg2 : f32
+ linalg.yield %y : f32
+ } -> tensor<f32>
+ return %r : tensor<f32>
+}
+
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d1)>
+// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0) -> (d0)>
+// CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0) -> ()>
+// CHECK-LABEL: func @generic_split_maxnumf
+// The float value 0xFFC00000 that is filled into the init tensor represents negative NaN.
+// CHECK-DAG: %[[ID:.*]] = arith.constant 0xFFC00000 : f32
+// CHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}}{{\[}}[0, 1]] output_shape [8, 4] : tensor<32xf32> into tensor<8x4xf32>
+// CHECK-DAG: %[[INI:.*]] = tensor.empty() : tensor<4xf32>
+// CHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<4xf32>) -> tensor<4xf32>
+// CHECK: %[[G:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["reduction", "parallel"]}
+// CHECK-SAME: ins(%[[I1]] : tensor<8x4xf32>) outs(%[[F]] : tensor<4xf32>) {
+// CHECK: arith.maxnumf
+// CHECK: linalg.yield
+// CHECK: } -> tensor<4xf32>
+// CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP3]]], iterator_types = ["reduction"]}
+// CHECK-SAME: ins(%[[G]] : tensor<4xf32>) outs(%{{.*}} : tensor<f32>) {
+// CHECK: arith.maxnumf {{.*}}
+// CHECK: linalg.yield
+// CHECK: } -> tensor<f32>
+// CHECK: return %[[R]] : tensor<f32>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1:4 = transform.structured.split_reduction %0 { split_factor = 4, insert_split_dimension = 0, inner_parallel}
+ : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+
+// -----
+// Checks we use nan as the neutral element for minnumf op.
+func.func @generic_split_minnumf(%in: tensor<32xf32>, %out: tensor<f32>) -> tensor<f32> {
+ %r = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
+ affine_map<(d0) -> ()>],
+ iterator_types = ["reduction"]}
+ ins(%in : tensor<32xf32>)
+ outs(%out : tensor<f32>) {
+ ^bb0(%arg1: f32, %arg2: f32):
+ %y = arith.minnumf %arg1, %arg2 : f32
+ linalg.yield %y : f32
+ } -> tensor<f32>
+ return %r : tensor<f32>
+}
+
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d1)>
+// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0) -> (d0)>
+// CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0) -> ()>
+// CHECK-LABEL: func @generic_split_minnumf
+// The float value 0x7FC00000 that is filled into the init tensor represents positive NaN.
+// CHECK-DAG: %[[ID:.*]] = arith.constant 0x7FC00000 : f32
+// CHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}}{{\[}}[0, 1]] output_shape [8, 4] : tensor<32xf32> into tensor<8x4xf32>
+// CHECK-DAG: %[[INI:.*]] = tensor.empty() : tensor<4xf32>
+// CHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<4xf32>) -> tensor<4xf32>
+// CHECK: %[[G:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["reduction", "parallel"]}
+// CHECK-SAME: ins(%[[I1]] : tensor<8x4xf32>) outs(%[[F]] : tensor<4xf32>) {
+// CHECK: arith.minnumf
+// CHECK: linalg.yield
+// CHECK: } -> tensor<4xf32>
+// CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP3]]], iterator_types = ["reduction"]}
+// CHECK-SAME: ins(%[[G]] : tensor<4xf32>) outs(%{{.*}} : tensor<f32>) {
+// CHECK: arith.minnumf {{.*}}
+// CHECK: linalg.yield
+// CHECK: } -> tensor<f32>
+// CHECK: return %[[R]] : tensor<f32>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1:4 = transform.structured.split_reduction %0 { split_factor = 4, insert_split_dimension = 0, inner_parallel}
+ : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}