[Mlir-commits] [mlir] Introduce `arith.scaling_extf` and `arith.scaling_truncf` (PR #141965)
Umang Yadav
llvmlistbot at llvm.org
Mon Jun 2 08:08:17 PDT 2025
================
@@ -1280,6 +1333,66 @@ def Arith_TruncFOp :
attr-dict `:` type($in) `to` type($out) }];
}
+//===----------------------------------------------------------------------===//
+// Scaling TruncFOp
+//===----------------------------------------------------------------------===//
+
+def Arith_ScalingTruncFOp
+ : Arith_Op<"scaling_truncf",
+ [Pure, SameInputOutputTensorDims,
+ DeclareOpInterfaceMethods<ArithRoundingModeInterface>,
+ DeclareOpInterfaceMethods<ArithFastMathInterface>,
+ DeclareOpInterfaceMethods<CastOpInterface>]>,
+ Arguments<(ins FloatLike:$in, FloatLike:$scale,
+ OptionalAttr<Arith_RoundingModeAttr>:$roundingmode,
+ OptionalAttr<Arith_FastMathAttr>:$fastmath)>,
+ Results<(outs FloatLike:$out)> {
+  let summary =
+      "Downcasts input floating point values using provided scale values following the OCP MXFP spec";
+ let description = [{
+    This operation quantizes the input using the provided scale values. It
+    expects the scale and input operands to have the same shape, which makes
+    the operation elementwise. Scales are usually calculated per block
+    following the OCP MXFP spec described in https://arxiv.org/abs/2310.10537.
+
+ If scales are calculated per block where blockSize != 1, scales may require
+ broadcasting to make this operation elementwise. For example, let's say the
+ input is of shape `<dim1 x dim2 x ... dimN>`. Given blockSize != 1 and
+ assuming quantization happens on the last axis, the input can be reshaped to
+ `<dim1 x dim2 x ... (dimN/blockSize) x blockSize>`. Scales will be calculated
+ per block on the last axis. Therefore, scales will be of shape
+ `<dim1 x dim2 x ... (dimN/blockSize) x 1>`. Scales could also be of some other
+ shape as long as it is broadcast compatible with the input, e.g.,
+ `<1 x 1 x ... (dimN/blockSize) x 1>`.
+
+    In this example, before calling into `arith.scaling_truncf`, scales must be
+    broadcast to `<dim1 x dim2 x dim3 ... (dimN/blockSize) x blockSize>`. Note
+ that there could be multiple quantization axes. Internally,
+ `arith.scaling_truncf` would perform the following:
+
+ ```
+ scaleETy = get_type(scale)
+ inputETy = get_type(input)
+ resultETy = get_type(result)
+ // prepare Scale values with normalization and clamping
+ scale.exponent = arith.truncf(scale.bcst) : scaleETy to f8E8M0
+ scale.extf = arith.extf(scale.exponent) : f8E8M0 to inputETy
+    // emax is the exponent of the largest normal value in the quantized type.
+    scale.normalize = arith.divf(scale.extf, emax)
+    scale.clamped = clamp(scale.normalize) // clamp underflows
+    input.flushed = flush_denorms(input)
----------------
umangyadav wrote:
IMO, that would be more detail than necessary.
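For readers following along, the per-block scale shapes and the f8E8M0-style truncation described in the quoted docs can be sketched in NumPy. This is a hypothetical illustration, not the actual MLIR lowering: the helper name `e8m0_trunc` and the choice of the block's max magnitude as the scale are assumptions for the example only.

```python
import numpy as np

def e8m0_trunc(scale):
    # f8E8M0 carries only an 8-bit exponent, so a positive scale rounds
    # down to the nearest power of two.
    return np.exp2(np.floor(np.log2(scale)))

block_size = 4
x = np.arange(1, 17, dtype=np.float32).reshape(2, 8)   # input <2 x 8>

# Reshape so the last axis splits into (dimN/blockSize) x blockSize.
blocks = x.reshape(2, 8 // block_size, block_size)     # <2 x 2 x 4>

# One scale per block, kept with a trailing unit axis (shape <2 x 2 x 1>)
# so that broadcasting makes the division elementwise, as the docs require.
scales = e8m0_trunc(np.abs(blocks).max(axis=-1, keepdims=True))

scaled = blocks / scales                               # <2 x 2 x 4>
print(scaled.shape)     # (2, 2, 4)
print(scales[0, 0, 0])  # first block's max is 4.0, already a power of two
```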
https://github.com/llvm/llvm-project/pull/141965