[Mlir-commits] [mlir] [mlir][ArithToAMDGPU] Add option for saturating truncation to fp8 (PR #74153)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Dec 1 15:04:52 PST 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-amdgpu

@llvm/pr-subscribers-mlir

Author: Krzysztof Drewniak (krzysz00)

<details>
<summary>Changes</summary>

Many machine-learning applications (and most software written at AMD) expect the operation that truncates floats to 8-bit floats to be saturating. That is, they expect `truncf 256.0 : f32 to f8E4M3FNUZ` to yield `240.0`, not `NaN`, and similarly for negative numbers. However, the underlying hardware instruction that can be used for this truncation implements overflow-to-NaN semantics.

To enable handling this use case, we add the saturate-fp8-truncf option to ArithToAMDGPU (off by default), which causes the requisite clamping code to be emitted. Said clamping code ensures that Inf and NaN are passed through exactly (and thus truncate to NaN).

---
Full diff: https://github.com/llvm/llvm-project/pull/74153.diff


4 Files Affected:

- (modified) mlir/include/mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h (+2-1) 
- (modified) mlir/include/mlir/Conversion/Passes.td (+6) 
- (modified) mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp (+65-5) 
- (added) mlir/test/Conversion/ArithToAMDGPU/8-bit-float-saturation.mlir (+57) 


``````````diff
diff --git a/mlir/include/mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h b/mlir/include/mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h
index 7f445fee5ba6b82..a1c059800752aca 100644
--- a/mlir/include/mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h
+++ b/mlir/include/mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h
@@ -20,7 +20,8 @@ class Pass;
 #include "mlir/Conversion/Passes.h.inc"
 
 namespace arith {
-void populateArithToAMDGPUConversionPatterns(RewritePatternSet &patterns);
+void populateArithToAMDGPUConversionPatterns(RewritePatternSet &patterns,
+                                             bool saturateFP8Truncf);
 } // namespace arith
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 06756ff3df0bb3b..2aa2ad634aeb722 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -125,6 +125,12 @@ def ArithToAMDGPUConversionPass : Pass<"convert-arith-to-amdgpu"> {
   }];
 
   let dependentDialects = ["amdgpu::AMDGPUDialect", "vector::VectorDialect"];
+
+  let options = [
+    Option<"saturateFP8Truncf", "saturate-fp8-truncf", "bool",
+           /*default=*/"false",
+           "Whether truncation to 8-bit float types should be saturating">,
+  ];
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index 7785405eae67be3..d6b916e6e55423d 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -44,7 +44,10 @@ struct ExtfOnFloat8RewritePattern final
 
 struct TruncfToFloat8RewritePattern final
     : public OpRewritePattern<arith::TruncFOp> {
-  using OpRewritePattern<arith::TruncFOp>::OpRewritePattern;
+  bool saturateFP8 = false;
+  TruncfToFloat8RewritePattern(MLIRContext *ctx, bool saturateFP8)
+      : OpRewritePattern<arith::TruncFOp>::OpRewritePattern(ctx),
+        saturateFP8(saturateFP8) {}
 
   LogicalResult match(arith::TruncFOp op) const override;
   void rewrite(arith::TruncFOp op, PatternRewriter &rewriter) const override;
@@ -127,6 +130,60 @@ static Value castToF32(Value value, Location loc, PatternRewriter &rewriter) {
   llvm_unreachable("The only 32-bit float type is f32");
 }
 
+static Value getMaybeVectorConstant(PatternRewriter &rewriter, Location loc,
+                                    const APFloat &value, Type type) {
+  if (isa<FloatType>(type))
+    return rewriter.createOrFold<arith::ConstantOp>(
+        loc, type, rewriter.getFloatAttr(type, value));
+  TypedAttr splat = DenseElementsAttr::get(cast<ShapedType>(type), value);
+  return rewriter.createOrFold<arith::ConstantOp>(loc, type, splat);
+}
+
+// If `in` is a finite value, clamp it between the maximum and minimum values
+// of `outElemType` so that subsequent conversion instructions don't
+// overflow those out-of-range values to NaN. These semantics are commonly
+// used in machine-learning contexts where failure to clamp would lead to
+// excessive NaN production.
+static Value clampInput(PatternRewriter &rewriter, Location loc,
+                        Type outElemType, Value source) {
+  Type sourceType = source.getType();
+  const llvm::fltSemantics &sourceSem =
+      cast<FloatType>(getElementTypeOrSelf(sourceType)).getFloatSemantics();
+  const llvm::fltSemantics &targetSem =
+      cast<FloatType>(outElemType).getFloatSemantics();
+
+  APFloat min = APFloat::getLargest(targetSem, /*Negative=*/true);
+  APFloat max = APFloat::getLargest(targetSem, /*Negative=*/false);
+  bool ignoredLosesInfo = false;
+  // We can ignore conversion failures here because this conversion promotes
+  // from a smaller type to a larger one.
+  (void)min.convert(sourceSem, APFloat::rmNearestTiesToEven, &ignoredLosesInfo);
+  (void)max.convert(sourceSem, APFloat::rmNearestTiesToEven, &ignoredLosesInfo);
+
+  Value minCst = getMaybeVectorConstant(rewriter, loc, min, sourceType);
+  Value maxCst = getMaybeVectorConstant(rewriter, loc, max, sourceType);
+
+  Value inf = getMaybeVectorConstant(
+      rewriter, loc, APFloat::getInf(sourceSem, /*Negative=*/false),
+      sourceType);
+  Value negInf = getMaybeVectorConstant(
+      rewriter, loc, APFloat::getInf(sourceSem, /*Negative=*/true), sourceType);
+  Value isInf = rewriter.createOrFold<arith::CmpFOp>(
+      loc, arith::CmpFPredicate::OEQ, source, inf);
+  Value isNegInf = rewriter.createOrFold<arith::CmpFOp>(
+      loc, arith::CmpFPredicate::OEQ, source, negInf);
+  Value isNan = rewriter.createOrFold<arith::CmpFOp>(
+      loc, arith::CmpFPredicate::UNO, source, source);
+  Value isNonFinite = rewriter.create<arith::OrIOp>(
+      loc, rewriter.create<arith::OrIOp>(loc, isInf, isNegInf), isNan);
+
+  Value clampedBelow = rewriter.create<arith::MaximumFOp>(loc, source, minCst);
+  Value clamped = rewriter.create<arith::MinimumFOp>(loc, clampedBelow, maxCst);
+  Value res =
+      rewriter.create<arith::SelectOp>(loc, isNonFinite, source, clamped);
+  return res;
+}
+
 LogicalResult TruncfToFloat8RewritePattern::match(arith::TruncFOp op) const {
   Type outType = op.getOut().getType();
   if (auto outVecType = outType.dyn_cast<VectorType>()) {
@@ -145,6 +202,8 @@ void TruncfToFloat8RewritePattern::rewrite(arith::TruncFOp op,
   Location loc = op.getLoc();
   Value in = op.getIn();
   Type outElemType = getElementTypeOrSelf(op.getOut().getType());
+  if (saturateFP8)
+    in = clampInput(rewriter, loc, outElemType, in);
   VectorType truncResType = VectorType::get(4, outElemType);
   if (!in.getType().isa<VectorType>()) {
     Value asFloat = castToF32(in, loc, rewriter);
@@ -196,15 +255,16 @@ void TruncfToFloat8RewritePattern::rewrite(arith::TruncFOp op,
 }
 
 void mlir::arith::populateArithToAMDGPUConversionPatterns(
-    RewritePatternSet &patterns) {
-  patterns.add<ExtfOnFloat8RewritePattern, TruncfToFloat8RewritePattern>(
-      patterns.getContext());
+    RewritePatternSet &patterns, bool saturateFP8Truncf) {
+  patterns.add<ExtfOnFloat8RewritePattern>(patterns.getContext());
+  patterns.add<TruncfToFloat8RewritePattern>(patterns.getContext(),
+                                             saturateFP8Truncf);
 }
 
 void ArithToAMDGPUConversionPass::runOnOperation() {
   Operation *op = getOperation();
   RewritePatternSet patterns(op->getContext());
-  arith::populateArithToAMDGPUConversionPatterns(patterns);
+  arith::populateArithToAMDGPUConversionPatterns(patterns, saturateFP8Truncf);
   if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
     return signalPassFailure();
 }
diff --git a/mlir/test/Conversion/ArithToAMDGPU/8-bit-float-saturation.mlir b/mlir/test/Conversion/ArithToAMDGPU/8-bit-float-saturation.mlir
new file mode 100644
index 000000000000000..d0c2cd4090117ff
--- /dev/null
+++ b/mlir/test/Conversion/ArithToAMDGPU/8-bit-float-saturation.mlir
@@ -0,0 +1,57 @@
+// RUN: mlir-opt --split-input-file %s \
+// RUN: --pass-pipeline='builtin.module(func.func(convert-arith-to-amdgpu{saturate-fp8-truncf=true}))' \
+// RUN: | FileCheck %s
+
+// CHECK-LABEL: func.func @scalar_trunc
+// CHECK-SAME: ([[V:%.+]]: f16)
+// CHECK-DAG: [[C0:%.+]] = arith.constant 0 : index
+// CHECK-DAG: [[CMin:%.+]] = arith.constant -5.734400e+04 : f16
+// CHECK-DAG: [[CMax:%.+]] = arith.constant 5.734400e+04 : f16
+// CHECK-DAG: [[CInf:%.+]] = arith.constant 0x7C00 : f16
+// CHECK-DAG: [[CNegInf:%.+]] = arith.constant 0xFC00 : f16
+// CHECK: [[ISINF:%.+]] = arith.cmpf oeq, [[V]], [[CInf]]
+// CHECK: [[ISNEGINF:%.+]] = arith.cmpf oeq, [[V]], [[CNegInf]]
+// CHECK: [[ISNAN:%.+]] = arith.cmpf uno, [[V]], [[V]]
+// CHECK: [[ISNONFINITE_1:%.+]] = arith.ori [[ISINF]], [[ISNEGINF]]
+// CHECK: [[ISNONFINITE:%.+]] = arith.ori [[ISNONFINITE_1]], [[ISNAN]]
+// CHECK: [[CLAMPEDBELOW:%.+]] = arith.maximumf [[V]], [[CMin]]
+// CHECK: [[CLAMPED:%.+]] = arith.minimumf [[CLAMPEDBELOW]], [[CMax]]
+// CHECK: [[SATURATED:%.+]] = arith.select [[ISNONFINITE]], [[V]], [[CLAMPED]]
+// CHECK: [[FLOAT:%.+]] = arith.extf [[SATURATED]] : f16 to f32
+// CHECK: [[TRUNCV:%.+]] = amdgpu.packed_trunc_2xfp8 [[FLOAT]], undef into undef[word 0] : f32 to vector<4xf8E5M2FNUZ>
+// CHECK: [[W:%.+]] = vector.extractelement [[TRUNCV]]{{\[}}[[C0]] : index] : vector<4xf8E5M2FNUZ>
+// CHECK: return [[W]] : f8E5M2FNUZ
+func.func @scalar_trunc(%v: f16) -> f8E5M2FNUZ {
+  %w = arith.truncf %v : f16 to f8E5M2FNUZ
+  return %w : f8E5M2FNUZ
+}
+
+// No 0-D test because arith.truncf hasn't been extended to support it.
+
+// -----
+
+// CHECK-LABEL: func.func @vector_trunc
+// CHECK-SAME: ([[V:%.+]]: vector<2xf32>) -> vector<2xf8E4M3FNUZ> {
+// CHECK-DAG: [[C0:%.+]] = arith.constant 0 : index
+// CHECK-DAG: [[C1:%.+]] = arith.constant 1 : index
+// CHECK-DAG: [[CMin:%.+]] = arith.constant dense<-2.400000e+02> : vector<2xf32>
+// CHECK-DAG: [[CMax:%.+]] = arith.constant dense<2.400000e+02> : vector<2xf32>
+// CHECK-DAG: [[CInf:%.+]] = arith.constant dense<0x7F800000> : vector<2xf32>
+// CHECK-DAG: [[CNegInf:%.+]] = arith.constant dense<0xFF800000> : vector<2xf32>
+// CHECK: [[ISINF:%.+]] = arith.cmpf oeq, [[V]], [[CInf]]
+// CHECK: [[ISNEGINF:%.+]] = arith.cmpf oeq, [[V]], [[CNegInf]]
+// CHECK: [[ISNAN:%.+]] = arith.cmpf uno, [[V]], [[V]]
+// CHECK: [[ISNONFINITE_1:%.+]] = arith.ori [[ISINF]], [[ISNEGINF]]
+// CHECK: [[ISNONFINITE:%.+]] = arith.ori [[ISNONFINITE_1]], [[ISNAN]]
+// CHECK: [[CLAMPEDBELOW:%.+]] = arith.maximumf [[V]], [[CMin]]
+// CHECK: [[CLAMPED:%.+]] = arith.minimumf [[CLAMPEDBELOW]], [[CMax]]
+// CHECK: [[SATURATED:%.+]] = arith.select [[ISNONFINITE]], [[V]], [[CLAMPED]]
+// CHECK: [[F0:%.+]] = vector.extractelement [[SATURATED]]{{\[}}[[C0]] : index]
+// CHECK: [[F1:%.+]] = vector.extractelement [[SATURATED]]{{\[}}[[C1]] : index]
+// CHECK: [[W0:%.+]] = amdgpu.packed_trunc_2xfp8 [[F0]], [[F1]] into undef[word 0] : f32 to vector<4xf8E4M3FNUZ>
+// CHECK: [[W:%.+]] = vector.extract_strided_slice [[W0]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xf8E4M3FNUZ> to vector<2xf8E4M3FNUZ>
+// CHECK: return [[W]] : vector<2xf8E4M3FNUZ>
+func.func @vector_trunc_short(%v: vector<2xf32>) -> vector<2xf8E4M3FNUZ> {
+  %w = arith.truncf %v : vector<2xf32> to vector<2xf8E4M3FNUZ>
+  return %w : vector<2xf8E4M3FNUZ>
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/74153


More information about the Mlir-commits mailing list