[Mlir-commits] [mlir] [mlir][ArithToAMDGPU] Add option for saturating truncation to fp8 (PR #74153)
Jakub Kuderski
llvmlistbot at llvm.org
Tue Jan 9 20:36:25 PST 2024
================
@@ -127,6 +130,60 @@ static Value castToF32(Value value, Location loc, PatternRewriter &rewriter) {
llvm_unreachable("The only 32-bit float type is f32");
}
+/// Materializes `value` as an `arith.constant` of type `type`.
+///
+/// If `type` is a scalar float type, the constant carries a FloatAttr.
+/// Otherwise `type` must be a shaped type (e.g. a vector), and the constant
+/// carries a splat DenseElementsAttr in which every element equals `value`.
+/// NOTE(review): `value` is expected to already use the float semantics of
+/// `type`'s element type; callers convert it before calling this helper.
+static Value getMaybeVectorConstant(PatternRewriter &rewriter, Location loc,
+                                    const APFloat &value, Type type) {
+  TypedAttr constantAttr;
+  if (isa<FloatType>(type))
+    constantAttr = rewriter.getFloatAttr(type, value);
+  else
+    // Non-scalar case: broadcast the single value across the whole shape.
+    constantAttr = DenseElementsAttr::get(cast<ShapedType>(type), value);
+  return rewriter.createOrFold<arith::ConstantOp>(loc, type, constantAttr);
+}
+
+// If `in` is a finite value, clamp it between the maximum and minimum values
+// of `outElemType` so that subsequent conversion instructions don't
+// overflow those out-of-range values to NaN. These semantics are commonly
+// used in machine-learning contexts where failure to clamp would lead to
+// excessive NaN production.
+static Value clampInput(PatternRewriter &rewriter, Location loc,
+ Type outElemType, Value source) {
+ Type sourceType = source.getType();
+ const llvm::fltSemantics &sourceSem =
+ cast<FloatType>(getElementTypeOrSelf(sourceType)).getFloatSemantics();
+ const llvm::fltSemantics &targetSem =
+ cast<FloatType>(outElemType).getFloatSemantics();
+
+ APFloat min = APFloat::getLargest(targetSem, /*Negative=*/true);
+ APFloat max = APFloat::getLargest(targetSem, /*Negative=*/false);
+ bool ignoredLosesInfo = false;
+ // We can ignore conversion failures here because this conversion promotes
+ // from a smaller type to a larger one.
+ (void)min.convert(sourceSem, APFloat::rmNearestTiesToEven, &ignoredLosesInfo);
+ (void)max.convert(sourceSem, APFloat::rmNearestTiesToEven, &ignoredLosesInfo);
+
+ Value minCst = getMaybeVectorConstant(rewriter, loc, min, sourceType);
+ Value maxCst = getMaybeVectorConstant(rewriter, loc, max, sourceType);
+
+ Value inf = getMaybeVectorConstant(
+ rewriter, loc, APFloat::getInf(sourceSem, /*Negative=*/false),
+ sourceType);
+ Value negInf = getMaybeVectorConstant(
+ rewriter, loc, APFloat::getInf(sourceSem, /*Negative=*/true), sourceType);
----------------
kuhar wrote:
Is the source type guaranteed to support infinities? I'm wondering if there are some existing / future edge cases where sourceType ends up being E4M3.
https://github.com/llvm/llvm-project/pull/74153
More information about the Mlir-commits
mailing list