[Mlir-commits] [mlir] f559e73 - [mlir] Support fast-math friendly constants for identity value

Quentin Colombet llvmlistbot at llvm.org
Wed Aug 9 05:28:12 PDT 2023


Author: Quentin Colombet
Date: 2023-08-09T14:22:18+02:00
New Revision: f559e73fad5bf6991411fa13a95ec6112745b8cf

URL: https://github.com/llvm/llvm-project/commit/f559e73fad5bf6991411fa13a95ec6112745b8cf
DIFF: https://github.com/llvm/llvm-project/commit/f559e73fad5bf6991411fa13a95ec6112745b8cf.diff

LOG: [mlir] Support fast-math friendly constants for identity value

Add an option to the family of `getIdentity` helper functions so that it is
possible to produce fast-math friendly constants.

For instance, for maxf the identity value is `-inf`. However, if the related
operations are lowered with fast-math (`ninf` in particular), then that value
becomes `poison`, and chances are the whole codegen is not going to do what we
want.
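
Over the finite inputs that `ninf` promises, `-inf` is not the only neutral
element for max: the most negative finite float works just as well. A minimal
standalone check of that reasoning, using `std::fmax` and
`std::numeric_limits<float>` as stand-ins for `arith.maxf` and APFloat (the
helper `isNeutralForMax` and the sample set are hypothetical):

```cpp
#include <cassert>
#include <cmath>
#include <limits>
#include <vector>

// Returns true if `id` is neutral for max over `samples`, i.e.
// fmax(id, x) == x for every sample x.
bool isNeutralForMax(float id, const std::vector<float> &samples) {
  for (float x : samples)
    if (std::fmax(id, x) != x)
      return false;
  return true;
}

// Finite samples only: under `ninf` the inputs are assumed never to be
// infinite, so this is the domain the identity has to cover.
const std::vector<float> kFiniteSamples = {
    std::numeric_limits<float>::lowest(), -1.0f, 0.0f, 2.5f,
    std::numeric_limits<float>::max()};
```

Both `-inf` and `std::numeric_limits<float>::lowest()` pass this check, while
a value such as `0.0f` does not.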

To avoid this problem, we add an option, `useOnlyFiniteValue`, to
`getIdentity` and friends that specifies whether a finite value must be
produced.
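
A rough sketch of what the option selects for f32, written against
`std::numeric_limits` rather than MLIR. It assumes the finite maxf identity is
the most negative finite float, as the new header comment describes
(`BiggestRepresentableNegativeFloat`); `ReductionKind` and `identityF32` are
hypothetical names, not MLIR API.

```cpp
#include <cassert>
#include <limits>

// Hypothetical subset of the reduction kinds handled by getIdentityValueAttr.
enum class ReductionKind { addf, maxf, minf };

// Sketch of the identity selection for f32, assuming the finite variant is
// the most extreme finite value in the "neutral" direction.
float identityF32(ReductionKind kind, bool useOnlyFiniteValue = false) {
  using Limits = std::numeric_limits<float>;
  switch (kind) {
  case ReductionKind::addf:
    // 0.0 is already finite, so the flag does not matter.
    return 0.0f;
  case ReductionKind::maxf:
    return useOnlyFiniteValue ? Limits::lowest() : -Limits::infinity();
  case ReductionKind::minf:
    return useOnlyFiniteValue ? Limits::max() : Limits::infinity();
  }
  return 0.0f; // Unreachable for valid enumerators.
}
```

The addf case is unaffected by the flag, which matches the diff leaving that
case untouched.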

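The patch is NFC except in two places: the lowering of `linalg::softmax`,
because we know we lower that with fast-math down the line, and
`getNeutralElement`, which now returns a finite value for ops carrying the
`ninf` fast-math flag.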
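
For context, the softmax decomposition computes the max along the reduction
dimension, subtracts it and exponentiates, sums, then divides, seeding the max
and the sum with the maxf and addf identities. A scalar sketch of one row,
assuming `std::numeric_limits<float>::lowest()` as the finite maxf seed
(`softmaxRow` is a hypothetical helper, not the Linalg code):

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Softmax over one row: max, then sub+exp, then sum, then divide.
std::vector<float> softmaxRow(const std::vector<float> &row) {
  // Step 1: max along the row, seeded with a finite maxf identity.
  float maxVal = std::numeric_limits<float>::lowest();
  for (float x : row)
    maxVal = std::max(maxVal, x);

  // Steps 2 and 3: exp(x - max) and its sum, seeded with the addf identity.
  std::vector<float> numerator(row.size());
  float sum = 0.0f;
  for (std::size_t i = 0; i < row.size(); ++i) {
    numerator[i] = std::exp(row[i] - maxVal);
    sum += numerator[i];
  }

  // Step 4: normalize.
  for (float &v : numerator)
    v /= sum;
  return numerator;
}
```

The max is only subtracted for numerical stability, so the seed must not be
larger than every element of a row: if it were, it would replace the row
maximum, and for very negative rows `exp` could underflow to zero.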

I didn't audit the rest of the code to check if it would make sense to set
this boolean in more places.

Note: it feels kind of wrong to have to know what the lowering may do, but
I don't know what the right solution is (at least short-term). Long term,
we may want a special "neutral element" attribute for the respective ops. I
haven't thought much about the implications of that.

Differential Revision: https://reviews.llvm.org/D156471

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Arith/IR/Arith.h
    mlir/lib/Dialect/Arith/IR/ArithOps.cpp
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/test/Dialect/Linalg/transform-op-decompose.mlir
    mlir/test/Dialect/Linalg/transform-op-split-reduction.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Arith/IR/Arith.h b/mlir/include/mlir/Dialect/Arith/IR/Arith.h
index 0abafa1d4c834b..971c78f4a86a75 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/Arith.h
+++ b/mlir/include/mlir/Dialect/Arith/IR/Arith.h
@@ -122,16 +122,28 @@ bool applyCmpPredicate(arith::CmpFPredicate predicate, const APFloat &lhs,
                        const APFloat &rhs);
 
 /// Returns the identity value attribute associated with an AtomicRMWKind op.
+/// `useOnlyFiniteValue` defines whether the identity value should steer away
+/// from infinity representations or anything that is not a proper finite
+/// number.
+/// E.g., The identity value for maxf is in theory `-Inf`, but if we want to
+/// stay in the finite range, it would be `BiggestRepresentableNegativeFloat`.
+/// The purpose of this boolean is to offer constants that will play nice
+/// with fast math related optimizations.
 TypedAttr getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
-                               OpBuilder &builder, Location loc);
+                               OpBuilder &builder, Location loc,
+                               bool useOnlyFiniteValue = false);
 
 /// Return the identity numeric value associated to the give op. Return
 /// std::nullopt if there is no known neutral element.
+/// If `op` has `FastMathFlags::ninf`, only finite values will be used
+/// as neutral element.
 std::optional<TypedAttr> getNeutralElement(Operation *op);
 
 /// Returns the identity value associated with an AtomicRMWKind op.
+/// \see getIdentityValueAttr for a description of what `useOnlyFiniteValue`
+/// does.
 Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder,
-                       Location loc);
+                       Location loc, bool useOnlyFiniteValue = false);
 
 /// Returns the value obtained by applying the reduction operation kind
 /// associated with a binary AtomicRMWKind op to `lhs` and `rhs`.

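The header comment for `getNeutralElement` ties the finite choice to
`FastMathFlags::ninf`; the implementation in ArithOps.cpp reduces that to a bit
test through `bitEnumContainsAny`. A standalone sketch of the decision; the
`FastMath` enum, its bit values, and the helper names are illustrative
assumptions rather than the generated MLIR enum:

```cpp
#include <cassert>
#include <cstdint>

// Illustrative stand-in for arith::FastMathFlags; bit values are assumed.
enum class FastMath : uint32_t {
  none = 0,
  reassoc = 1u << 0,
  nnan = 1u << 1,
  ninf = 1u << 2,
};

constexpr FastMath operator|(FastMath a, FastMath b) {
  return static_cast<FastMath>(static_cast<uint32_t>(a) |
                               static_cast<uint32_t>(b));
}

// True if any bit of `mask` is set in `flags`.
constexpr bool containsAny(FastMath flags, FastMath mask) {
  return (static_cast<uint32_t>(flags) & static_cast<uint32_t>(mask)) != 0;
}

// Mirrors the getNeutralElement decision: only `ninf` requests a finite
// identity; an op without fast-math flags (modelled as `none`) keeps +/-inf.
constexpr bool useOnlyFiniteValue(FastMath flags) {
  return containsAny(flags, FastMath::ninf);
}
```
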
diff  --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 482198771c3f6a..44a3217714dcbf 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -2387,13 +2387,17 @@ OpFoldResult arith::ShRSIOp::fold(FoldAdaptor adaptor) {
 
 /// Returns the identity value attribute associated with an AtomicRMWKind op.
 TypedAttr mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
-                                            OpBuilder &builder, Location loc) {
+                                            OpBuilder &builder, Location loc,
+                                            bool useOnlyFiniteValue) {
   switch (kind) {
-  case AtomicRMWKind::maxf:
-    return builder.getFloatAttr(
-        resultType,
-        APFloat::getInf(llvm::cast<FloatType>(resultType).getFloatSemantics(),
-                        /*Negative=*/true));
+  case AtomicRMWKind::maxf: {
+    const llvm::fltSemantics &semantic =
+        llvm::cast<FloatType>(resultType).getFloatSemantics();
+    APFloat identity = useOnlyFiniteValue
+                           ? APFloat::getSmallest(semantic, /*Negative=*/true)
+                           : APFloat::getInf(semantic, /*Negative=*/true);
+    return builder.getFloatAttr(resultType, identity);
+  }
   case AtomicRMWKind::addf:
   case AtomicRMWKind::addi:
   case AtomicRMWKind::maxu:
@@ -2407,11 +2411,15 @@ TypedAttr mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
     return builder.getIntegerAttr(
         resultType, APInt::getSignedMinValue(
                         llvm::cast<IntegerType>(resultType).getWidth()));
-  case AtomicRMWKind::minf:
-    return builder.getFloatAttr(
-        resultType,
-        APFloat::getInf(llvm::cast<FloatType>(resultType).getFloatSemantics(),
-                        /*Negative=*/false));
+  case AtomicRMWKind::minf: {
+    const llvm::fltSemantics &semantic =
+        llvm::cast<FloatType>(resultType).getFloatSemantics();
+    APFloat identity = useOnlyFiniteValue
+                           ? APFloat::getLargest(semantic, /*Negative=*/false)
+                           : APFloat::getInf(semantic, /*Negative=*/false);
+
+    return builder.getFloatAttr(resultType, identity);
+  }
   case AtomicRMWKind::mins:
     return builder.getIntegerAttr(
         resultType, APInt::getSignedMaxValue(
@@ -2457,17 +2465,28 @@ std::optional<TypedAttr> mlir::arith::getNeutralElement(Operation *op) {
     return std::nullopt;
   }
 
+  bool useOnlyFiniteValue = false;
+  auto fmfOpInterface = dyn_cast<ArithFastMathInterface>(op);
+  if (fmfOpInterface) {
+    arith::FastMathFlagsAttr fmfAttr = fmfOpInterface.getFastMathFlagsAttr();
+    useOnlyFiniteValue =
+        bitEnumContainsAny(fmfAttr.getValue(), arith::FastMathFlags::ninf);
+  }
+
   // Builder only used as helper for attribute creation.
   OpBuilder b(op->getContext());
   Type resultType = op->getResult(0).getType();
 
-  return getIdentityValueAttr(*maybeKind, resultType, b, op->getLoc());
+  return getIdentityValueAttr(*maybeKind, resultType, b, op->getLoc(),
+                              useOnlyFiniteValue);
 }
 
 /// Returns the identity value associated with an AtomicRMWKind op.
 Value mlir::arith::getIdentityValue(AtomicRMWKind op, Type resultType,
-                                    OpBuilder &builder, Location loc) {
-  auto attr = getIdentityValueAttr(op, resultType, builder, loc);
+                                    OpBuilder &builder, Location loc,
+                                    bool useOnlyFiniteValue) {
+  auto attr =
+      getIdentityValueAttr(op, resultType, builder, loc, useOnlyFiniteValue);
   return builder.create<arith::ConstantOp>(loc, attr);
 }
 

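For reference, `APFloat::getSmallest(semantic, /*Negative=*/true)` returns the
negative finite value of smallest magnitude (a denormal for IEEE types), which
matches the `-1.401300e-45` constant the updated tests check, while the
`BiggestRepresentableNegativeFloat` named in the header comment corresponds to
`APFloat::getLargest(semantic, /*Negative=*/true)`. The two behave differently
as the seed of a max reduction whose inputs are all negative. A standalone
sketch using `std::numeric_limits<float>` as a stand-in for APFloat
(`maxReduce` is a hypothetical helper):

```cpp
#include <cassert>
#include <cmath>
#include <limits>
#include <vector>

// f32 counterparts of the APFloat constructors discussed above:
//   APFloat::getSmallest(sem, /*Negative=*/true) ~ -denorm_min(), about -1.4e-45
//   APFloat::getLargest(sem, /*Negative=*/true)  ~ lowest(), about -3.4e+38
const float kNegSmallest = -std::numeric_limits<float>::denorm_min();
const float kNegLargest = std::numeric_limits<float>::lowest();

// Max over `data`, seeded with `init` the way a reduction fills its
// accumulator with the identity value.
float maxReduce(const std::vector<float> &data, float init) {
  float acc = init;
  for (float x : data)
    acc = std::fmax(acc, x);
  return acc;
}
```

With all-negative inputs, only the `lowest()` seed leaves the true maximum
intact; the `-denorm_min()` seed wins every comparison.
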
diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 7cee8c616a239d..b59ea45e03240b 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2492,7 +2492,8 @@ FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
   // Step 1: Compute max along dim.
   Value outputReduce = b.create<tensor::EmptyOp>(loc, dims, elementType);
   Value neutralForMaxF =
-      arith::getIdentityValue(arith::AtomicRMWKind::maxf, elementType, b, loc);
+      arith::getIdentityValue(arith::AtomicRMWKind::maxf, elementType, b, loc,
+                              /*useOnlyFiniteValue=*/true);
   Value neutralForMaxFInit =
       b.create<linalg::FillOp>(loc, Value{neutralForMaxF}, outputReduce)
           .result();
@@ -2503,8 +2504,8 @@ FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
   Value numerator = buildSubAndExpOp(b, loc, input, max, output, reductionDim);
 
   // Step 3: Compute sum along dim.
-  Value zero =
-      arith::getIdentityValue(arith::AtomicRMWKind::addf, elementType, b, loc);
+  Value zero = arith::getIdentityValue(arith::AtomicRMWKind::addf, elementType,
+                                       b, loc, /*useOnlyFiniteValue=*/true);
   Value zeroInit =
       b.create<linalg::FillOp>(loc, Value{zero}, outputReduce).result();
   Value denominator =

diff  --git a/mlir/test/Dialect/Linalg/transform-op-decompose.mlir b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
index 30a155e28a966a..a8520b45275bb2 100644
--- a/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
@@ -210,7 +210,7 @@ func.func @softmax(%arg0: tensor<2x16x32xf32>, %dst: tensor<2x16x32xf32>) -> ten
 // CHECK-LABEL:      func.func @softmax(
 // CHECK-SAME:           %[[ARG0:[a-zA-Z0-9_]+]]: tensor<2x16x32xf32>, %[[DST:[a-zA-Z0-9_]+]]: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
 // CHECK-DAG:        %[[D1:.+]] = tensor.empty() : tensor<2x16xf32>
-// CHECK-DAG:        %[[CST:.+]] = arith.constant 0xFF800000 : f32
+// CHECK-DAG:        %[[CST:.+]] = arith.constant -1.401300e-45 : f32
 // CHECK:        %[[D2:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D1]] : tensor<2x16xf32>) -> tensor<2x16xf32>
 // CHECK:        %[[D3:.+]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]]], iterator_types = ["parallel",
 // CHECK-SAME:     "parallel", "reduction"]} ins(%[[ARG0]] : tensor<2x16x32xf32>) outs(%[[D2]] : tensor<2x16xf32>) {

diff  --git a/mlir/test/Dialect/Linalg/transform-op-split-reduction.mlir b/mlir/test/Dialect/Linalg/transform-op-split-reduction.mlir
index 9d8f2ed5640ba4..783632cc73f623 100644
--- a/mlir/test/Dialect/Linalg/transform-op-split-reduction.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-split-reduction.mlir
@@ -141,6 +141,62 @@ transform.sequence failures(propagate) {
 
 // -----
 
+// Check that we don't use -inf as the neutral element for maxf when maxf has
+// ninf. Instead check that we use the smallest finite floating point value.
+// Also check that the fastmath flags are set on the created maxf
+// instructions.
+func.func @generic_split_3d_ninf(%input: tensor<32x2xf32>, %input_2: tensor<5x32xf32>, %output: tensor<5x2xf32>)
+  -> tensor<5x2xf32>
+{
+  %0 = linalg.generic {
+      indexing_maps = [
+        affine_map<(d0, d1, d2) -> (d1, d0)>,
+        affine_map<(d0, d1, d2) -> (d2, d1)>,
+        affine_map<(d0, d1, d2) -> (d2, d0)>
+      ],
+      iterator_types = ["parallel", "reduction", "parallel"]
+    } ins(%input, %input_2 : tensor<32x2xf32>, tensor<5x32xf32>) outs(%output : tensor<5x2xf32>) {
+    ^bb0(%arg0: f32, %arg1: f32, %arg2: f32):
+      %3 = arith.addf %arg0, %arg1 : f32
+      %4 = arith.maxf %3, %arg2 fastmath<nnan,ninf> : f32
+      linalg.yield %4 : f32
+    } -> tensor<5x2xf32>
+  return %0 : tensor<5x2xf32>
+}
+
+//  CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d2, d1, d0)>
+//  CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d3, d2, d1)>
+//  CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d3, d0, d2)>
+//  CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+//  CHECK-DAG: #[[$MAP4:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL:  func @generic_split_3d_ninf
+//  CHECK-DAG: %[[ID:.*]] = arith.constant -1.401300e-45 : f32
+//  CHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0, 1], [2]] : tensor<32x2xf32> into tensor<4x8x2xf32>
+//  CHECK-DAG: %[[I2:.*]] = tensor.expand_shape %{{.*}}[0], [1, 2]] : tensor<5x32xf32> into tensor<5x4x8xf32>
+//  CHECK-DAG: %[[INI:.*]] = tensor.empty() : tensor<5x2x4xf32>
+//      CHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<5x2x4xf32>) -> tensor<5x2x4xf32>
+//      CHECK: %[[G:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "reduction", "parallel", "parallel"]}
+// CHECK-SAME:   ins(%[[I1]], %[[I2]] : tensor<4x8x2xf32>, tensor<5x4x8xf32>) outs(%[[F]] : tensor<5x2x4xf32>) {
+//      CHECK:   arith.addf
+//      CHECK:   arith.maxf {{.*}} fastmath<nnan,ninf>
+//      CHECK:   linalg.yield
+//      CHECK: } -> tensor<5x2x4xf32>
+//      CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP3]], #[[$MAP4]]], iterator_types = ["parallel", "parallel", "reduction"]}
+// CHECK-SAME:   ins(%[[G]] : tensor<5x2x4xf32>) outs(%{{.*}} : tensor<5x2xf32>) {
+//      CHECK:   arith.maxf {{.*}} fastmath<nnan,ninf>
+//      CHECK:   linalg.yield
+//      CHECK:  } -> tensor<5x2xf32>
+//      CHECK: return %[[R]] : tensor<5x2xf32>
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  %1:4 = transform.structured.split_reduction %0 { split_factor = 4, insert_split_dimension = 2}
+    : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+}
+
+// -----
+
 func.func @matmul_split(%A : tensor<16x256xf32>, %B: tensor<256x32xf32>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
   %0 = linalg.matmul ins(%A, %B: tensor<16x256xf32>, tensor<256x32xf32>)
                     outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
@@ -279,3 +335,59 @@ transform.sequence failures(propagate) {
   %1:4 = transform.structured.split_reduction %0 { split_factor = 4, insert_split_dimension = 2, inner_parallel}
     : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
 }
+
+// -----
+
+// Check that we don't use +inf as the neutral element for minf when minf has
+// ninf. Instead check that we use the largest finite floating point value.
+// Also check that the fastmath flags are set on the created minf
+// instructions.
+func.func @generic_split_3d(%input: tensor<32x2xf32>, %input_2: tensor<5x32xf32>, %output: tensor<5x2xf32>)
+  -> tensor<5x2xf32>
+{
+  %0 = linalg.generic {
+      indexing_maps = [
+        affine_map<(d0, d1, d2) -> (d1, d0)>,
+        affine_map<(d0, d1, d2) -> (d2, d1)>,
+        affine_map<(d0, d1, d2) -> (d2, d0)>
+      ],
+      iterator_types = ["parallel", "reduction", "parallel"]
+    } ins(%input, %input_2 : tensor<32x2xf32>, tensor<5x32xf32>) outs(%output : tensor<5x2xf32>) {
+    ^bb0(%arg0: f32, %arg1: f32, %arg2: f32):
+      %3 = arith.addf %arg0, %arg1 : f32
+      %4 = arith.minf %3, %arg2 fastmath<ninf> : f32
+      linalg.yield %4 : f32
+    } -> tensor<5x2xf32>
+  return %0 : tensor<5x2xf32>
+}
+
+//  CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d2, d0)>
+//  CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d3, d1, d2)>
+//  CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d3, d0, d2)>
+//  CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+//  CHECK-DAG: #[[$MAP4:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL:  func @generic_split_3d
+//  CHECK-DAG: %[[ID:.*]] = arith.constant 3.40282347E+38 : f32
+//  CHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0, 1], [2]] : tensor<32x2xf32> into tensor<8x4x2xf32>
+//  CHECK-DAG: %[[I2:.*]] = tensor.expand_shape %{{.*}}[0], [1, 2]] : tensor<5x32xf32> into tensor<5x8x4xf32>
+//  CHECK-DAG: %[[INI:.*]] = tensor.empty() : tensor<5x2x4xf32>
+//      CHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<5x2x4xf32>) -> tensor<5x2x4xf32>
+//      CHECK: %[[G:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "reduction", "parallel", "parallel"]}
+// CHECK-SAME:   ins(%[[I1]], %[[I2]] : tensor<8x4x2xf32>, tensor<5x8x4xf32>) outs(%[[F]] : tensor<5x2x4xf32>) {
+//      CHECK:   arith.addf
+//      CHECK:   arith.minf {{.*}} fastmath<ninf>
+//      CHECK:   linalg.yield
+//      CHECK: } -> tensor<5x2x4xf32>
+//      CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP3]], #[[$MAP4]]], iterator_types = ["parallel", "parallel", "reduction"]}
+// CHECK-SAME:   ins(%[[G]] : tensor<5x2x4xf32>) outs(%{{.*}} : tensor<5x2xf32>) {
+//      CHECK:   arith.minf {{.*}} fastmath<ninf>
+//      CHECK:   linalg.yield
+//      CHECK:  } -> tensor<5x2xf32>
+//      CHECK: return %[[R]] : tensor<5x2xf32>
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  %1:4 = transform.structured.split_reduction %0 { split_factor = 4, insert_split_dimension = 2, inner_parallel}
+    : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+}
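
The split-reduction tests check that the partial accumulators are filled with
the neutral element (the `linalg.fill` of `%[[ID]]`) before the partial
reductions run, and that a second reduction combines the partials into the
5x2 result. A scalar sketch of that two-stage scheme for maxf; `splitMaxReduce`
and `kSplit` are hypothetical, mapping element `i` to partial `i % kSplit` is
one possible split rather than the transform's exact indexing, and seeding the
final combine with the original output value is an assumption of the sketch:

```cpp
#include <algorithm>
#include <array>
#include <cassert>
#include <cstddef>
#include <limits>
#include <vector>

constexpr int kSplit = 4;

// Stage 1: kSplit partial max reductions, each accumulator pre-filled with
// `identity`. Stage 2: combine the partials, seeded with `outInit`
// (assumed to be the original output value).
float splitMaxReduce(const std::vector<float> &input, float outInit,
                     float identity) {
  std::array<float, kSplit> partial;
  partial.fill(identity);
  for (std::size_t i = 0; i < input.size(); ++i) {
    float &acc = partial[i % kSplit];
    acc = std::max(acc, input[i]);
  }
  float result = outInit;
  for (float p : partial)
    result = std::max(result, p);
  return result;
}
```

Any identity that is truly neutral gives the same answer as an unsplit
reduction; a seed larger than every input would leak into the result.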

