[Mlir-commits] [mlir] [mlir][ArithToAMDGPU] Add option for saturating truncation to fp8 (PR #74153)

Jakub Kuderski llvmlistbot at llvm.org
Tue Jan 9 20:36:25 PST 2024


================
@@ -127,6 +130,60 @@ static Value castToF32(Value value, Location loc, PatternRewriter &rewriter) {
   llvm_unreachable("The only 32-bit float type is f32");
 }
 
+// Materializes `value` as an `arith.constant` whose result type is `type`.
+// A scalar FloatType is given a FloatAttr. Any other type is treated as a
+// ShapedType (e.g. a vector) and is given a splat DenseElementsAttr. The op is
+// created through createOrFold.
+// NOTE(review): `value` presumably must already use the float semantics of
+// `type`'s element type (clampInput converts to `sourceSem` first) — confirm.
+static Value getMaybeVectorConstant(PatternRewriter &rewriter, Location loc,
+                                    const APFloat &value, Type type) {
+  TypedAttr constAttr;
+  if (isa<FloatType>(type)) {
+    constAttr = rewriter.getFloatAttr(type, value);
+  } else {
+    // cast<> asserts (in assertion-enabled builds) if `type` is not shaped.
+    constAttr = DenseElementsAttr::get(cast<ShapedType>(type), value);
+  }
+  return rewriter.createOrFold<arith::ConstantOp>(loc, type, constAttr);
+}
+
+// If `in` is a finite value, clamp it between the maximum and minimum values
+// of `outElemType` so that subsequent conversion instructions don't
+// overflow those out-of-range values to NaN. These semantics are commonly
+// used in machine-learning contexts where failure to clamp would lead to
+// excessive NaN production.
+static Value clampInput(PatternRewriter &rewriter, Location loc,
+                        Type outElemType, Value source) {
+  Type sourceType = source.getType();
+  const llvm::fltSemantics &sourceSem =
+      cast<FloatType>(getElementTypeOrSelf(sourceType)).getFloatSemantics();
+  const llvm::fltSemantics &targetSem =
+      cast<FloatType>(outElemType).getFloatSemantics();
+
+  APFloat min = APFloat::getLargest(targetSem, /*Negative=*/true);
+  APFloat max = APFloat::getLargest(targetSem, /*Negative=*/false);
+  bool ignoredLosesInfo = false;
+  // We can ignore conversion failures here because this conversion promotes
+  // from a smaller type to a larger one.
+  (void)min.convert(sourceSem, APFloat::rmNearestTiesToEven, &ignoredLosesInfo);
+  (void)max.convert(sourceSem, APFloat::rmNearestTiesToEven, &ignoredLosesInfo);
+
+  Value minCst = getMaybeVectorConstant(rewriter, loc, min, sourceType);
+  Value maxCst = getMaybeVectorConstant(rewriter, loc, max, sourceType);
+
+  Value inf = getMaybeVectorConstant(
+      rewriter, loc, APFloat::getInf(sourceSem, /*Negative=*/false),
+      sourceType);
+  Value negInf = getMaybeVectorConstant(
+      rewriter, loc, APFloat::getInf(sourceSem, /*Negative=*/true), sourceType);
----------------
kuhar wrote:

Is the source type guaranteed to support infinities? I'm wondering whether there are any existing or future edge cases where `sourceType` ends up being an E4M3 type; the E4M3FN/E4M3FNUZ formats have no infinity encoding.

https://github.com/llvm/llvm-project/pull/74153


More information about the Mlir-commits mailing list